In [7]:
import torch
import wandb
import argparse
import losses as l
import torch.nn as nn
from data import LoadData 
from simclr import SimCLR
from utils import get_acc
import torch.nn.functional as F

# Start the Weights & Biases run at import time; train() below reports its metrics to it via wandb.log.
wandb.init(project="associative-vision-models")

class FineTune(nn.Module):
    """Linear classification head stacked on top of a backbone module.

    ``forward`` runs ``x`` through ``model`` and maps the result to
    ``num_classes`` logits with a single ``nn.Linear`` layer, which acts on
    the last dimension of the backbone output.

    Args:
        model: Backbone module; its output's last dimension must equal
            ``in_features``.
        in_features: Size of the backbone output features (default 2048,
            the value previously hard-coded).
        num_classes: Number of output logits (default 5, the value
            previously hard-coded).
    """

    def __init__(self, model, in_features=2048, num_classes=5):
        super().__init__()
        self.model = model
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.linear(self.model(x))

def _batch_loss(model, inputs, labels, loss_fn, device, mode):
    """Compute the loss for one batch and return ``(loss, logits, labels)``.

    In "SSL" mode every element of ``inputs`` is fed through the model
    separately and all outputs are passed to ``loss_fn``; ``labels`` is
    returned untouched. Otherwise ``labels`` is reduced with
    ``argmax(dim=1)`` (presumably one-hot targets -- confirm) and moved to
    ``device`` before the loss is computed.
    """
    if mode == "SSL":
        # NOTE(review): the elements of `inputs` are presumably the augmented
        # views of the batch -- confirm against LoadData.
        logits = [model(view.to(device)) for view in inputs]
        return loss_fn(*logits), logits, labels
    labels = torch.argmax(labels, dim=1).to(device)
    logits = model(inputs.to(device))
    return loss_fn(logits, labels), logits, labels


def _run_epoch(model, loader, loss_fn, device, mode, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, mean_acc)``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch. Without one it is put in eval mode and run under
    ``torch.no_grad()``, so evaluation builds no autograd graph and any
    dropout/batch-norm layers use inference behaviour. ``mean_acc`` is
    ``None`` in "SSL" mode, where no accuracy is computed.
    """
    training = optimizer is not None
    model.train(training)
    losses, accs = [], []
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            if training:
                optimizer.zero_grad()
            loss, logits, labels = _batch_loss(model, inputs, labels, loss_fn, device, mode)
            if training:
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            if mode != "SSL":
                accs.append(get_acc(logits, labels))
    mean_loss = torch.mean(torch.tensor(losses))
    mean_acc = torch.mean(torch.tensor(accs)) if mode != "SSL" else None
    return mean_loss, mean_acc


def train(model, data_aug, batch_size, epochs, save_as, save_every=0):
    """Pre-train or fine-tune a SimCLR model, logging metrics to wandb.

    Args:
        model: ``"simclr"`` trains a fresh SimCLR model self-supervised with
            ``NT_Xent`` ("SSL" mode). Any other string containing
            ``"simclr"`` (case-insensitive) is treated as a path to a saved
            SimCLR state dict, which is loaded, wrapped in ``FineTune`` and
            trained with Adam + cross-entropy ("Finetune" mode).
        data_aug: Only used when ``model == "simclr"``: ``"default"`` selects
            the default SimCLR train/eval transforms, anything else the
            random-masking transform.
        batch_size: Batch size for every dataloader.
        epochs: Number of training epochs.
        save_as: Output path without extension; the final state dict is
            written to ``f"{save_as}.pt"``.
        save_every: If positive, also write that checkpoint every
            ``save_every`` epochs so a crash mid-run does not lose all
            progress. The default 0 saves only once at the end.

    Raises:
        ValueError: if ``model`` selects neither mode.
    """
    is_pretrain = model == "simclr"
    if not is_pretrain and "simclr" not in model.lower():
        # Fail early instead of running on with no optimizer / loss defined.
        raise ValueError(
            f"Unsupported model {model!r}: expected 'simclr' or a SimCLR checkpoint path"
        )

    if is_pretrain:
        if data_aug == "default":
            transforms = [LoadData.default_simclr_train(), LoadData.default_simclr_eval()]
        else:
            transforms = [LoadData.random_masking_transform()]
    else:
        transforms = [LoadData.default_transform()]
    data = LoadData(transforms).generate_split_dataloader()
    masked_test_data = LoadData([LoadData.default_transform()], "LFW_masked").generate_dataloader()

    if is_pretrain:
        net = torch.jit.script(SimCLR(batch_size, len(data(1, "train")), epochs=epochs))
        optimizer, scheduler = net.configure_optimizer()
        loss_fn = l.NT_Xent(batch_size)
        mode = "SSL"
    else:
        params = torch.load(model)
        # NOTE(review): batch size 32 is hard-coded here, presumably to match
        # the configuration the checkpoint was pre-trained with -- confirm.
        backbone = torch.jit.script(SimCLR(32, len(data(1, "train")), epochs=epochs))
        backbone.load_state_dict(params)
        net = FineTune(backbone)
        optimizer, scheduler = torch.optim.Adam(net.parameters()), None
        loss_fn = nn.CrossEntropyLoss()
        mode = "Finetune"

    train_loader = data(batch_size, "train")
    val_loader = data(batch_size, "val")
    test_loader = data(batch_size, "test")
    masked_test_loader = masked_test_data(batch_size)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = net.to(device)

    for epoch in range(epochs):
        print(f"Epoch: {epoch}")
        metrics = {"epoch": epoch}

        train_loss, train_acc = _run_epoch(net, train_loader, loss_fn, device, mode, optimizer)
        val_loss, val_acc = _run_epoch(net, val_loader, loss_fn, device, mode)
        metrics["train_loss"] = train_loss
        metrics["val_loss"] = val_loss
        if mode != "SSL":
            metrics["train_acc"] = train_acc
            metrics["val_acc"] = val_acc
        if scheduler:
            scheduler.step()

        if mode != "SSL":
            # Distinct keys per test set and metric so none overwrites another.
            test_loss, test_acc = _run_epoch(net, test_loader, loss_fn, device, mode)
            masked_loss, masked_acc = _run_epoch(net, masked_test_loader, loss_fn, device, mode)
            metrics["unmasked_test_loss"] = test_loss
            metrics["unmasked_test_acc"] = test_acc
            metrics["masked_test_loss"] = masked_loss
            metrics["masked_test_acc"] = masked_acc

        # One wandb step per epoch, carrying every metric for that epoch.
        wandb.log(metrics)

        if save_every > 0 and (epoch + 1) % save_every == 0:
            torch.save(net.state_dict(), f"{save_as}.pt")

    torch.save(net.state_dict(), f"{save_as}.pt")

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [8]:
# Fine-tune from the SimCLR checkpoint SimCLR_Checkpoints/try2.pt (batch size 64, 100 epochs).
# data_aug="default" is ignored in fine-tune mode; weights are saved to SimCLR_Checkpoints/fine_tune.pt.
train("SimCLR_Checkpoints/try2.pt", "default", 64, 100, "SimCLR_Checkpoints/fine_tune")

Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11
Epoch: 12
Epoch: 13
Epoch: 14
Epoch: 15
Epoch: 16
Epoch: 17
Epoch: 18
Epoch: 19
Epoch: 20
Epoch: 21
Epoch: 22
Epoch: 23
Epoch: 24
Epoch: 25
Epoch: 26
Epoch: 27
Epoch: 28
Epoch: 29
Epoch: 30
Epoch: 31
Epoch: 32
Epoch: 33
Epoch: 34
Epoch: 35
Epoch: 36
Epoch: 37
Epoch: 38
Epoch: 39
Epoch: 40
Epoch: 41
Epoch: 42
Epoch: 43
Epoch: 44
Epoch: 45
Epoch: 46
Epoch: 47
Epoch: 48
Epoch: 49
Epoch: 50
Epoch: 51
Epoch: 52
Epoch: 53
Epoch: 54
Epoch: 55
Epoch: 56
Epoch: 57
Epoch: 58
Epoch: 59
Epoch: 60
Epoch: 61
Epoch: 62
Epoch: 63
Epoch: 64
Epoch: 65
Epoch: 66
Epoch: 67
Epoch: 68
Epoch: 69
Epoch: 70
Epoch: 71
Epoch: 72
Epoch: 73
Epoch: 74
Epoch: 75
Epoch: 76
Epoch: 77
Epoch: 78
Epoch: 79


Thread SenderThread:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/internal/internal_util.py", line 51, in run
    self._run()
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/internal/internal_util.py", line 102, in _run
    self._process(record)
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/internal/internal.py", line 310, in _process
    self._sm.send(record)
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/internal/sender.py", line 237, in send
    send_handler(record)
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/internal/sender.py", line 830, in send_summary
    self._update_summary()
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/internal/sender.py", line 842, in _update_summary
    with open(summary_path, "w") as f:
FileNotFoundError: [Errno 2] No such file or directory: '/home/jupyter/wandb/run-20220930_021451-3h8htrvc/files/wandb-summary.json'
wandb: ERROR Internal wandb error: file 

Exception: The wandb backend process has shutdown