## Swin Training for Contact-Map Header

Preliminary Parameters

In [1]:
n_frames_per_seq : int = 8
emb_dim : int = 768
predictor_input_size = emb_dim + n_frames_per_seq*42

Imports

In [2]:
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter # pip install tensorboard

Path Handling

In [3]:
h2o_root = '../data/h2o/'
sample_root_train = h2o_root + f'seq_{n_frames_per_seq}_train/'
sample_root_val = h2o_root + f'seq_{n_frames_per_seq}_val/'

### Dataloader

In [4]:
class DataTrain(torch.utils.data.DataLoader):
    def __init__(self, mode, h2o_dir, sample_dir):
        assert mode in ['train', 'val'], f'Invalid mode {mode}. Expected: train, val'
        self.sample_dir = sample_dir
        self.emb_dir = sample_dir + f'emb_swin_{mode}/'
        # self.dist_dir = sample_dir + f'distances_{mode}/'
        self.dist_dir = sample_dir + f'cm_{mode}/'
        self.labels = np.load(h2o_dir + f'action_labels_{mode}.npy')
        self.n_actions = self.labels.shape[0]

    def __len__(self):
        return self.n_actions
    
    def __getitem__(self, idx):
        emb = np.load(self.emb_dir + f'{idx+1}.npy')
        dist = np.load(self.dist_dir + f'{(idx+1):03d}.npy').flatten()
        return np.hstack([emb, dist]).astype(np.float32), self.labels[idx]-1

### Model

Note: 36 action labels (1-36)

In [5]:
# TODO define MLP predictor model
class ActionPredictor(nn.Module):
    """
    The model class, which defines our classifier.
    """
    def __init__(self, input_size):
        """
        The constructor of the model.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_size, 100)
        self.fc2 = nn.Linear(100, 36)

    def forward(self, x):
        """
        The forward pass of the model.

        input: x: torch.Tensor, the input to the model

        output: x: torch.Tensor, the output of the model
        """
        x = self.fc1(x)
        x = self.fc2(F.relu(x))
        x = F.sigmoid(x)
        return x

### Training

Parameters

In [9]:
n_epochs = 120
lr = 0.001
momentum = 0.98
print_every_iters = 230
weight_dir = 'weights/cmhead/'

Preparations

In [11]:
# Dataloaders
data_train = DataTrain('train', h2o_root, sample_root_train)
data_val = DataTrain('val', h2o_root, sample_root_val)
train_loader = torch.utils.data.DataLoader(data_train, batch_size=2, shuffle=False, num_workers=4)
val_loader = torch.utils.data.DataLoader(data_val, batch_size=2, shuffle=False, num_workers=4)

# model
model = ActionPredictor(predictor_input_size)
model.train()

# define optimizer and loss function
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# criterion = nn.MSELoss()

# tensorboard
tb_writer = SummaryWriter('log')

Training

In [13]:
for epoch in range(120, 160):
    print(20*'=', 'Epoch %d' % (epoch + 1), 20*'=')

    running_loss = 0.0
    for i, (X, y) in enumerate(train_loader):
        optimizer.zero_grad()

        outputs = model(X) # shape (1,36) or (2,36)
        # outputs = outputs.squeeze()
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        # print batch statistics
        running_loss += loss.item()
        if (i + 1) % print_every_iters == 0:
            print(
                f'[Epoch: {epoch + 1} / {n_epochs},'
                f' Iter: {i + 1:5d} / {len(train_loader)}]'
                f' Training loss: {running_loss / (i + 1):.3f}'
            )

    # epoch statistics
    mean_loss = running_loss / len(train_loader)
    tb_writer.add_scalar('Training Loss', mean_loss, epoch)
    print('Training Loss: %.3f' % (mean_loss))


    running_loss = 0.0
    total, correct = 0, 0
    for (X, y) in val_loader:
        with torch.no_grad():
            outputs = model(X)
            # outputs = outputs.squeeze()
            loss = criterion(outputs, y)
        running_loss += loss.item()
        
        # make prediction
        _, predicted = torch.max(outputs.data, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

    accuracy = correct / total
    mean_loss = running_loss / len(val_loader)
    tb_writer.add_scalar('Validation Loss', mean_loss, epoch)
    tb_writer.add_scalar('Validation Accuracy', accuracy, epoch)
    tb_writer.flush()
    print('Validation Loss: %.3f\nValidation Acc: %.3f' % (mean_loss, accuracy))

    # save weights of current epoch
    torch.save(model.state_dict(), f'{weight_dir}e_{epoch}.pt')

tb_writer.close()
print(f'Successfully trained {n_epochs} epochs.')

[Epoch: 121 / 120, Iter:   230 / 285] Training loss: 2.640
Training Loss: 2.639
Validation Loss: 2.894
Validation Acc: 0.795
[Epoch: 122 / 120, Iter:   230 / 285] Training loss: 2.640
Training Loss: 2.639
Validation Loss: 2.894
Validation Acc: 0.795
[Epoch: 123 / 120, Iter:   230 / 285] Training loss: 2.640
Training Loss: 2.639
Validation Loss: 2.894
Validation Acc: 0.795
[Epoch: 124 / 120, Iter:   230 / 285] Training loss: 2.640
Training Loss: 2.639
Validation Loss: 2.894
Validation Acc: 0.795
[Epoch: 125 / 120, Iter:   230 / 285] Training loss: 2.639
Training Loss: 2.639
Validation Loss: 2.894
Validation Acc: 0.795
[Epoch: 126 / 120, Iter:   230 / 285] Training loss: 2.639
Training Loss: 2.639
Validation Loss: 2.894
Validation Acc: 0.795
[Epoch: 127 / 120, Iter:   230 / 285] Training loss: 2.639
Training Loss: 2.639
Validation Loss: 2.894
Validation Acc: 0.795
[Epoch: 128 / 120, Iter:   230 / 285] Training loss: 2.639
Training Loss: 2.639
Validation Loss: 2.894
Validation Acc: 0.795


Load best parameters

In [None]:
model.load_state_dict(torch.load('...'))