In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from torch.utils.data import DataLoader
from model import *

from utils.dataset import *
import torch.nn as nn
from pathlib import Path
import utils.transforms  as T
import os
import torchvision

In [3]:
# Allow TensorFloat-32 in both cuBLAS matmuls and cuDNN convolutions: faster on
# TF32-capable GPUs at the cost of slightly reduced float32 precision.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend

In [4]:
def DebugImages(running_count, images, output_dir="debug_images",
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization and save an image batch as one PNG.

    Args:
        running_count: Used as the file name, i.e. ``<output_dir>/<running_count>.png``.
        images: Normalized image batch; assumed (B, C, H, W) with C == len(mean)
            -- TODO confirm against the dataset transform.
        output_dir: Directory to write into; created if it does not exist.
        mean, std: Per-channel statistics the images were normalized with.
            Defaults are the ImageNet values previously hard-coded here.
    """
    # Build the statistics on the images' own device rather than forcing .cuda(),
    # so the helper also works on CPU tensors / CPU-only machines.
    mean_t = torch.as_tensor(mean, device=images.device).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, device=images.device).view(1, -1, 1, 1)
    images = images * std_t + mean_t

    out_dir = Path(output_dir)
    # save_image does not create parent directories.
    out_dir.mkdir(parents=True, exist_ok=True)
    torchvision.utils.save_image(images, out_dir / f"{running_count}.png")


In [5]:
# --- Data / embedding configuration ------------------------------------------
image_count = 5            # passed to the dataset and every transform as the per-sample image count
latent_size = 512          # passed to the dataset as embedding_size (and to get_model below)
is_dynamic = True
sample_count = None
aggregation_simple = False  # True: mean over per-image outputs; False: model.aggregate
embeddings_file_path = "deepsdf_latent_codes/deepsdf_generalization_final_512v2/latent_best.ckpt"
EMBEDDING_NAMES_PATH = "data/embeddings/obj_files.json"

# Root of the image folders. Override with the ML43D_DATA_ROOT environment variable
# instead of editing a machine-specific absolute path.
DATA_ROOT = os.environ.get("ML43D_DATA_ROOT", "D:/TUM/ML43D_Project")
IMAGE_SIZE = 224           # target size passed to T.Resize
BATCH_SIZE = 4
NUM_WORKERS = 8


def make_dataset(split_dir, transform):
    """Build a CustomDataset for one split; all non-split settings are shared."""
    return CustomDataset(image_folder_paths=f"{DATA_ROOT}/{split_dir}",
                         embeddings_file_path=embeddings_file_path,
                         embedding_name_list_file_path=EMBEDDING_NAMES_PATH,
                         embedding_size=latent_size,
                         image_count=image_count,
                         sample_count=sample_count,
                         dynamic=is_dynamic,
                         transform=transform)


# Augmentations are applied to the training split only.
# Previously tried and disabled: T.RandomCrop(224, image_count), T.Rotation(-5, 5, 0.4, image_count).
train_transform = transforms.Compose([
    T.Resize(IMAGE_SIZE, IMAGE_SIZE, image_count),
    T.GaussianBlur(0.1, 5, image_count),
    T.Noise(0.1, (-20, -20, -20), (20, 20, 20), 15, image_count),
    T.SwitchRGB(0.1, image_count),
    OmniObject.ToTensor(divide255=True, img_count=image_count),
])
val_transform = transforms.Compose([
    T.Resize(IMAGE_SIZE, IMAGE_SIZE, image_count),
    OmniObject.ToTensor(divide255=True, img_count=image_count),
])

train_dataset = make_dataset("TrainDataSet", train_transform)
val_dataset = make_dataset("ValidationDataSet", val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

ToTensor divide255: True
ToTensor divide255: True


In [6]:
# Build the image encoder from model.get_model. latent_size is presumably the output
# embedding width (outputs are regressed onto latent_size-dim targets) and
# aggregation_simple presumably toggles whether a learned aggregation head is built
# -- NOTE(review): confirm both in model.get_model.
model = get_model(latent_size, aggregation_simple=aggregation_simple)

In [7]:
# Loss between the aggregated prediction and the target embedding.
# (nn.L1Loss() is the commented-out alternative that was used before.)
LEARNING_RATE = 1e-3

# To resume training, point these at a previous run's checkpoints,
# e.g. RESUME_DIR = "MAE_loss_checkpoints", RESUME_EPOCH = 290.
RESUME_DIR = None
RESUME_EPOCH = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model BEFORE constructing/loading the optimizer. Otherwise a loaded Adam
# state is cast to the CPU parameters and stays on CPU after a later model.to("cuda"),
# and optimizer.step() then fails with a device mismatch.
model = model.to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

if RESUME_DIR is not None and RESUME_EPOCH is not None:
    model.load_state_dict(
        torch.load(Path(RESUME_DIR, f"model_epoch_{RESUME_EPOCH}.pth"), map_location=device))
    optimizer.load_state_dict(
        torch.load(Path(RESUME_DIR, f"optimizer_epoch_{RESUME_EPOCH}.pth"), map_location=device))


In [8]:
# Training-run settings: upper bound for the epoch counter and the directory
# where per-epoch model/optimizer checkpoints are written.
NUM_EPOCHS = 1_000
OUTPUT_PATH = "outputs"

In [9]:
phases = {"TRAIN": train_loader, "VAL": val_loader}

model = model.to(device)
# Created once up front instead of on every epoch.
Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)


def run_epoch(phase, data_loader, epoch):
    """Run one pass over `data_loader`, updating weights only when phase == "TRAIN".

    Prints a live per-batch progress line and a final summary line.
    Returns the per-object mean loss over the epoch (0.0 for an empty loader).
    """
    is_train = phase == "TRAIN"
    model.train(is_train)  # train(False) is exactly model.eval()

    running_loss = 0.0
    n_seen = 0

    for i, batch in enumerate(data_loader, start=1):
        # Moved to the device once (the original transferred each tensor twice).
        img_models = batch["imgs"].to(device)  # assumed (N, V, C, H, W) -- TODO confirm
        embeddings = batch["embeddings"].to(device)

        # Deliberately NOT named `T`: that would overwrite the `utils.transforms as T`
        # module alias and break a later re-run of the dataset cell.
        n_objects, n_views = img_models.shape[0], img_models.shape[1]
        img_models = torch.flatten(img_models, start_dim=0, end_dim=1)  # (N*V, ...)

        with torch.set_grad_enabled(is_train):
            if is_train:
                optimizer.zero_grad()

            outputs = model(img_models)  # (N*V, latent)
            # DebugImages(i, img_models)  # uncomment to dump the input views
            outputs = torch.reshape(outputs, (n_objects, n_views, -1))  # (N, V, latent)
            if aggregation_simple:
                pred_embeds = torch.mean(outputs, dim=1)
            else:
                pred_embeds = model.aggregate(outputs)

            loss = criterion(pred_embeds, embeddings)
            if is_train:
                loss.backward()
                optimizer.step()

        batch_loss = loss.item()  # single host sync per batch
        running_loss += batch_loss * n_objects
        n_seen += n_objects

        print(f"\r[{phase}]  Epoch: {epoch}/{NUM_EPOCHS}  Batch: {i}/{len(data_loader)}  Epoch_Loss: {running_loss/n_seen:.8f}  Batch_Loss:{batch_loss:.8f}", end="")

    epoch_loss = running_loss / max(n_seen, 1)  # guard against an empty loader
    print(f"\n[{phase}]  Epoch: {epoch}  Loss: {epoch_loss:.8f}")
    return epoch_loss


# Epochs are numbered 1..NUM_EPOCHS inclusive, matching the "epoch/NUM_EPOCHS" display.
for epoch in range(1, NUM_EPOCHS + 1):
    for phase, data_loader in phases.items():
        run_epoch(phase, data_loader, epoch)

        if phase == "TRAIN":
            torch.save(model.state_dict(), Path(OUTPUT_PATH, f"model_epoch_{epoch}.pth"))
            torch.save(optimizer.state_dict(), Path(OUTPUT_PATH, f"optimizer_epoch_{epoch}.pth"))
            

[TRAIN]  Epoch: 1/1000  Batch: 276/276  Epoch_Loss: 0.00138119  Batch_Loss:0.00139377
[TRAIN]  Epoch: 1  Loss: 0.00138119
[VAL]  Epoch: 1/1000  Batch: 28/28  Epoch_Loss: 0.00131337  Batch_Loss:0.00112267
[VAL]  Epoch: 1  Loss: 0.00131337
[TRAIN]  Epoch: 2/1000  Batch: 276/276  Epoch_Loss: 0.00131538  Batch_Loss:0.00102183
[TRAIN]  Epoch: 2  Loss: 0.00131538
[VAL]  Epoch: 2/1000  Batch: 28/28  Epoch_Loss: 0.00143432  Batch_Loss:0.00155397
[VAL]  Epoch: 2  Loss: 0.00143432
[TRAIN]  Epoch: 3/1000  Batch: 276/276  Epoch_Loss: 0.00133168  Batch_Loss:0.00127095
[TRAIN]  Epoch: 3  Loss: 0.00133168
[VAL]  Epoch: 3/1000  Batch: 28/28  Epoch_Loss: 0.00132889  Batch_Loss:0.00122962
[VAL]  Epoch: 3  Loss: 0.00132889
[TRAIN]  Epoch: 4/1000  Batch: 276/276  Epoch_Loss: 0.00127172  Batch_Loss:0.00112312
[TRAIN]  Epoch: 4  Loss: 0.00127172
[VAL]  Epoch: 4/1000  Batch: 28/28  Epoch_Loss: 0.00129975  Batch_Loss:0.00113544
[VAL]  Epoch: 4  Loss: 0.00129975
[TRAIN]  Epoch: 5/1000  Batch: 276/276  Epoch_Lo

KeyboardInterrupt: 