## Swin Training for Contact-Map Header

Preliminary Parameters

In [1]:
n_frames_per_seq : int = 8
emb_dim : int = 768
predictor_input_size = emb_dim
predictor_output_size = n_frames_per_seq*42

Imports

In [2]:
import sys
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter # pip install tensorboard

Path Handling

In [3]:
h2o_root = '../data/h2o/'
sample_root_train = h2o_root + f'seq_{n_frames_per_seq}_train/'
sample_root_val = h2o_root + f'seq_{n_frames_per_seq}_val/'

### Dataloader

In [4]:
class DataTrain(torch.utils.data.DataLoader):
    def __init__(self, mode, h2o_dir, sample_dir):
        assert mode in ['train', 'val'], f'Invalid mode {mode}. Expected: train, val'
        self.emb_dir = sample_dir + f'emb_swin_{mode}/'
        self.dist_dir = sample_dir + f'cm_{mode}/'
        labels = np.load(h2o_dir + f'action_labels_{mode}.npy')
        self.n_actions = labels.shape[0]

    def __len__(self):
        return self.n_actions
    
    def __getitem__(self, idx):
        emb = np.load(self.emb_dir + f'{idx+1}.npy')
        dist = np.load(self.dist_dir + f'{(idx+1):03d}.npy').flatten()
        return emb.astype(np.float32), dist.astype(np.int32)

### Model

Note: 36 action labels (1-36)

In [40]:
# TODO define MLP predictor model
class ContactPredictor(nn.Module):
    """
    The model class, which defines our classifier.
    """
    def __init__(self, input_size, output_size):
        """
        The constructor of the model.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_size, 500)
        self.fc2 = nn.Linear(500, output_size)

    def forward(self, x):
        """
        The forward pass of the model.

        input: x: torch.Tensor, the input to the model

        output: x: torch.Tensor, the output of the model
        """
        x = self.fc1(x)
        x = self.fc2(F.relu(x))
        x = F.sigmoid(x)
        return x

### Training

Parameters

In [60]:
n_epochs = 200
lr = 0.001
momentum = 0.98
print_every_iters = 230
weight_dir = 'weights/cmpred2/'
os.makedirs(weight_dir, exist_ok=True)

In [58]:
def compute_weighted_loss(labels):
    """
    The loss function of the model.

    input: labels: torch.Tensor, the labels of the data

    output: loss: torch.Tensor, the loss of the model
    """
    assert labels.shape[0] == 1, "Multiple batches not supported"
    contacts = (labels==1).squeeze()
    weight = torch.ones(labels.squeeze().shape)
    weight[contacts] = 50.0
    return nn.BCELoss(weight=weight)

Preparations

In [59]:
# Dataloaders
data_train = DataTrain('train', h2o_root, sample_root_train)
data_val = DataTrain('val', h2o_root, sample_root_val)
train_loader = torch.utils.data.DataLoader(data_train, batch_size=1, shuffle=False, num_workers=4)
val_loader = torch.utils.data.DataLoader(data_val, batch_size=1, shuffle=False, num_workers=4)

# model
model = ContactPredictor(predictor_input_size, predictor_output_size)
model.train()

# define optimizer and loss function
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# criterion = nn.CrossEntropyLoss()
# criterion = nn.MSELoss()
criterion = nn.BCELoss(reduction='mean')

# tensorboard
tb_writer = SummaryWriter('log_cmpred2')

Training

In [61]:
# initial accuracy
total, correct, running_loss = 0, 0, 0
for (X, y) in val_loader:
    with torch.no_grad():
        criterion = compute_weighted_loss(y)
        y = y.float()
        outputs = model(X)
        loss = criterion(outputs, y)
    running_loss += loss.item()
    
    # make prediction
    outputs = outputs > 0.5 # round to binary field
    outputs = outputs.int()
    correct += (outputs == y).sum().item()
    total += y.numel()

accuracy = correct / total
mean_loss = running_loss / len(val_loader)
print('Validation Loss: %.3f\nValidation Acc: %.3f' % (mean_loss, accuracy))

Validation Loss: 5.987
Validation Acc: 0.492


In [62]:
for epoch in range(n_epochs):
# for epoch in range(n_epochs, n_epochs*2+1):
    print(20*'=', 'Epoch %d' % (epoch + 1), 20*'=')

    running_loss = 0.0
    for i, (X, y) in enumerate(train_loader):
        optimizer.zero_grad()
        criterion = compute_weighted_loss(y)
        y = y.float()

        outputs = model(X) # shape (batch_size, 42*8)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        # print batch statistics
        running_loss += loss.item()
        if (i + 1) % print_every_iters == 0:
            print(
                f'[Epoch: {epoch + 1} / {n_epochs},'
                f' Iter: {i + 1:5d} / {len(train_loader)}]'
                f' Training loss: {running_loss / (i + 1):.3f}'
            )

    # epoch statistics
    mean_loss = running_loss / len(train_loader)
    tb_writer.add_scalar('Training Loss', mean_loss, epoch)
    print('Training Loss: %.3f' % (mean_loss))


    running_loss = 0.0
    total, correct = 0, 0
    correct_contact, correct_nocontact = 0, 0
    total_contact, total_nocontact = 0, 0
    for (X, y) in val_loader:
        with torch.no_grad():
            criterion = compute_weighted_loss(y)
            y = y.float()
            outputs = model(X)
            loss = criterion(outputs, y)
        running_loss += loss.item()
        
        # make prediction
        # x, y have shape (batch_size, 42*8)
        # print("output shape:", outputs.shape)
        # print("y shape:", y.shape)
        # print("outputs:", outputs)
        # print("y:", y)
        # print("y int:", y.int())
        outputs = outputs > 0.5 # round to binary field
        outputs = outputs.int()

        # compute accuracy for contact and no-contact points
        contact_mask = (y == 1).squeeze()
        no_contact_mask = (y == 0).squeeze()
        total_contact += contact_mask.sum().item()
        total_nocontact += no_contact_mask.sum().item()

        correct_preds = (outputs==y).squeeze()
        correct_contact += (correct_preds[contact_mask]).sum().item()
        correct_nocontact += (correct_preds[no_contact_mask]).sum().item()
        correct += correct_preds.sum().item()
        total += y.numel()
        # print("outputs rounded:", outputs)
        # print("correct:", correct, "total:", total)
        # stop execution here
        # sys.exit()

    accuracy = correct / total
    accuracy_contact = correct_contact / total_contact
    accuracy_nocontact = correct_nocontact / total_nocontact
    mean_loss = running_loss / len(val_loader)
    tb_writer.add_scalar('Validation Loss', mean_loss, epoch)
    tb_writer.add_scalar('Validation Accuracy', accuracy, epoch)
    tb_writer.flush()
    print('Validation Loss: %.3f\nValidation Acc: %.3f' % (mean_loss, accuracy))
    print('Validation Acc Contact: %.3f\nValidation Acc No-Contact: %.3f' % (accuracy_contact, accuracy_nocontact))

    # save weights of current epoch
    torch.save(model.state_dict(), f'{weight_dir}e_{epoch}.pt')

tb_writer.close()
print(f'Successfully trained {n_epochs} epochs.')

[Epoch: 1 / 200, Iter:   230 / 569] Training loss: 3.455
[Epoch: 1 / 200, Iter:   460 / 569] Training loss: 2.819
Training Loss: 2.750
Validation Loss: 2.206
Validation Acc: 0.331
Validation Acc Contact: 0.987
Validation Acc No-Contact: 0.210
[Epoch: 2 / 200, Iter:   230 / 569] Training loss: 2.277
[Epoch: 2 / 200, Iter:   460 / 569] Training loss: 2.170
Training Loss: 2.198
Validation Loss: 2.203
Validation Acc: 0.329
Validation Acc Contact: 0.989
Validation Acc No-Contact: 0.208
[Epoch: 3 / 200, Iter:   230 / 569] Training loss: 2.180
[Epoch: 3 / 200, Iter:   460 / 569] Training loss: 2.084
Training Loss: 2.101
Validation Loss: 2.207
Validation Acc: 0.357
Validation Acc Contact: 0.988
Validation Acc No-Contact: 0.241
[Epoch: 4 / 200, Iter:   230 / 569] Training loss: 2.041
[Epoch: 4 / 200, Iter:   460 / 569] Training loss: 1.965
Training Loss: 1.974
Validation Loss: 2.331
Validation Acc: 0.410
Validation Acc Contact: 0.975
Validation Acc No-Contact: 0.306
[Epoch: 5 / 200, Iter:   230

Load best parameters

In [None]:
model.load_state_dict(torch.load('...'))