# Abrestnet code file starts here

In [1]:
import argparse
import os
from tqdm import tqdm
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter

import sys
sys.path.append('/home/pasang/all_experiment/FvFold')

import fvfold
from fvfold.models.AbResNet import AbResNet, load_model

from fvfold.util.masking import MASK_VALUE
from fvfold.util.training import check_for_h5_file
# from fvfold.util.util import RawTextArgumentDefaultsHelpFormatter
from fvfold.datasets.H5PairwiseGeometryDataset import H5PairwiseGeometryDataset
from fvfold.preprocess.generate_h5_pairwise_geom_file import antibody_to_h5

_output_names = ['ca_dist', 'cb_dist', 'no_dist', 'omega', 'theta', 'phi']
os.getcwd()

  from .autonotebook import tqdm as notebook_tqdm


'/home/pasang/all_experiment/FvFold/fvfold/models/AbResNet'

In [2]:
import torch

if torch.cuda.is_available():
    print("GPUs available:")
    for i in range(torch.cuda.device_count()):
        device_properties = torch.cuda.get_device_properties(i)
        print(f"{i}: {device_properties.name}, Memory: {device_properties.total_memory / (1024**2):.2f} MB")
else:
    print("No GPUs available.")

GPUs available:
0: NVIDIA GeForce RTX 2080 Ti, Memory: 11019.56 MB
1: NVIDIA GeForce RTX 2080 Ti, Memory: 11019.56 MB
2: NVIDIA GeForce RTX 2080 Ti, Memory: 11019.56 MB
3: NVIDIA GeForce RTX 2080 Ti, Memory: 11019.56 MB
4: NVIDIA GeForce RTX 2080 Ti, Memory: 11019.56 MB
5: NVIDIA GeForce RTX 2080 Ti, Memory: 11019.56 MB
6: NVIDIA RTX A6000, Memory: 48685.31 MB
7: NVIDIA A100 80GB PCIe, Memory: 81069.75 MB


In [3]:
project_path = os.path.abspath(os.path.join(fvfold.__file__, "../.."))
num_blocks1D=3
num_blocks2D=25
dilation_cycle=5
num_bins=37
dropout=0.2
num_blocks2D=25
dilation_cycle=5
num_bins=37
dropout=0.2
epochs=200
save_every=5
batch_size=8
lr=0.01
use_gpu=True
train_split=0.9
h5_file= os.path.join(project_path, 'data/final_antibody.h5')
antibody_database=os.path.join(project_path,'data/antibody_database')
now = str(datetime.now().strftime('%y-%m-%d %H:%M:%S'))
default_model_path = os.path.join(project_path,'trained_models/model_{}/'.format(now))
output_dir=default_model_path
pretrain_model_file=None
random_seed=0
project_path


'/home/pasang/all_experiment/FvFold'

In [4]:
device_type = 'cuda:7' if torch.cuda.is_available(
    ) and use_gpu else 'cpu'
device = torch.device(device_type)
out_dir = output_dir
current_epoch = 0

In [5]:
class FocalLoss(nn.modules.loss._WeightedLoss):
    def __init__(self,
                    weight=None,
                    gamma=2,
                    reduction='mean',
                    ignore_index=MASK_VALUE):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        ce_loss = F.cross_entropy(input,
                                    target,
                                    reduction=self.reduction,
                                    weight=self.weight,
                                    ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt)**self.gamma * ce_loss).mean()
        return focal_loss

In [6]:
# properties of resnet and 2d resnet
properties = dict(num_out_bins=num_bins,
                    num_blocks1D=num_blocks1D,
                    num_blocks2D=num_blocks2D,
                    dropout_proportion=dropout,
                    dilation_cycle=dilation_cycle)
print("abresnet properties",properties)

abresnet properties {'num_out_bins': 37, 'num_blocks1D': 3, 'num_blocks2D': 25, 'dropout_proportion': 0.2, 'dilation_cycle': 5}


In [7]:
model = AbResNet(21, **properties)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
properties.update({'lr': lr})
print("properties",properties)
print("model",model)

properties {'num_out_bins': 37, 'num_blocks1D': 3, 'num_blocks2D': 25, 'dropout_proportion': 0.2, 'dilation_cycle': 5, 'lr': 0.01}
model AbResNet(
  (linear): Linear(in_features=1024, out_features=128, bias=True)
  (relu): ReLU()
  (resnet1D): ResNet1D(
    (conv1): Conv1d(21, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
    (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer0): Sequential(
      (0): ResBlock1D(
        (conv1): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ResBlock1D(
        (conv1): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn1): BatchNorm1d(32

In [8]:
print('Making {} ...'.format(out_dir))
os.mkdir(out_dir)

print(properties)

Making /home/pasang/all_experiment/FvFold/trained_models/model_23-12-26 16:00:10/ ...
{'num_out_bins': 37, 'num_blocks1D': 3, 'num_blocks2D': 25, 'dropout_proportion': 0.2, 'dilation_cycle': 5, 'lr': 0.01}


# Load only train loader not validation for testing purpose

In [9]:
properties.update({'lr': lr})

check_for_h5_file(h5_file, antibody_to_h5, antibody_database)
dataset = H5PairwiseGeometryDataset(h5_file,
                                    num_bins=num_bins,
                                    mask_distant_orientations=True)
train_split_length = int(len(dataset) * train_split)
torch.manual_seed(random_seed)
train_dataset, validation_dataset = data.random_split(
    dataset, [train_split_length,
                len(dataset) - train_split_length])
train_loader = data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=H5PairwiseGeometryDataset.merge_samples_to_minibatch)
validation_loader = data.DataLoader(
    validation_dataset,
    batch_size=batch_size,
    collate_fn=H5PairwiseGeometryDataset.merge_samples_to_minibatch)
len(dataset)

22

In [10]:
criterion = FocalLoss(ignore_index=dataset.mask_fill_value)

lr_modifier = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                            verbose=True)
writer = SummaryWriter(os.path.join(out_dir, 'tensorboard'))




print('Model:\n', model)

Model:
 AbResNet(
  (linear): Linear(in_features=1024, out_features=128, bias=True)
  (relu): ReLU()
  (resnet1D): ResNet1D(
    (conv1): Conv1d(21, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
    (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer0): Sequential(
      (0): ResBlock1D(
        (conv1): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ResBlock1D(
        (conv1): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(32, 32, kernel_size=(17,), stride=(1,),

In [11]:
model=model
train_loader=train_loader
# validation_loader=validation_loader
optimizer=optimizer
device=device
epochs=epochs
current_epoch=current_epoch
criterion=criterion
lr_modifier=lr_modifier
writer=writer
save_file=os.path.join(out_dir, 'model')
save_every=save_every
properties=properties
properties

{'num_out_bins': 37,
 'num_blocks1D': 3,
 'num_blocks2D': 25,
 'dropout_proportion': 0.2,
 'dilation_cycle': 5,
 'lr': 0.01}

In [12]:
properties = {} if properties is None else properties
print('Using {} as device'.format(str(device).upper()))
model = model.to(device)
print(model)
loss_size = len(_output_names) + 1

Using CUDA:7 as device
AbResNet(
  (linear): Linear(in_features=1024, out_features=128, bias=True)
  (relu): ReLU()
  (resnet1D): ResNet1D(
    (conv1): Conv1d(21, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
    (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer0): Sequential(
      (0): ResBlock1D(
        (conv1): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ResBlock1D(
        (conv1): Conv1d(32, 32, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
        (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(32, 32, kernel_size=(17,

In [13]:
def train_epoch(model, train_loader, optimizer, device, criterion, loss_size):
    """Trains a model for one epoch"""
    model.train()
    running_losses = torch.zeros(loss_size)
    for inputs,e,max_len,labels in tqdm(train_loader, total=len(train_loader)):
        inputs = inputs.to(device)
        labels = [label.to(device) for label in labels]
        
        emb=torch.stack([x.clone().detach() for x in e])
        tensor_list = [torch.tensor(x, dtype=torch.int32) for x in max_len]
        unpadded_tensor = emb[:, :, :max(tensor_list)]
        

        optimizer.zero_grad()

        def handle_batch():
            outputs = model(inputs,unpadded_tensor.to(device))

            losses = [
                criterion(output, label)
                for output, label in zip(outputs, labels)
            ]
            total_loss = sum(losses)
            losses.append(total_loss)

            total_loss.backward()
            optimizer.step()
            return outputs, torch.Tensor([float(l.item()) for l in losses])

        outputs, batch_loss = handle_batch()
        running_losses += batch_loss

    return running_losses


def validate(model, validation_loader, device, criterion, loss_size):
    """"""
    with torch.no_grad():
        model.eval()
        running_losses = torch.zeros(loss_size)
        for  inputs,e,max_len,labels in tqdm(validation_loader,
                                   total=len(validation_loader)):
            inputs = inputs.to(device)
            labels = [label.to(device) for label in labels]
            emb=torch.stack([x.clone().detach() for x in e])
            tensor_list = [torch.tensor(x, dtype=torch.int32) for x in max_len]
            unpadded_tensor = emb[:, :, :max(tensor_list)]

            def handle_batch():
                outputs = model(inputs,unpadded_tensor.to(device))

                losses = [
                    criterion(output, label)
                    for output, label in zip(outputs, labels)
                ]
                total_loss = sum(losses)
                losses.append(total_loss)

                return outputs, torch.Tensor([float(l.item()) for l in losses])

            outputs, batch_loss = handle_batch()
            running_losses += batch_loss
    return running_losses

In [14]:
use_loss=[]
avg_v_loss=[]
avg_t_loss=[]
early_stop = False
best_val_loss = float('inf')
patience = 40
counter = 0

for epoch in range(current_epoch, epochs):
    train_losses = train_epoch(model, train_loader, optimizer, device, criterion, loss_size)
    avg_train_losses = train_losses / len(train_loader)

    use_loss.append(avg_train_losses[6])
    train_loss_dict = dict(zip(_output_names + ['total'], avg_train_losses.tolist()))
    writer.add_scalars('train_loss', train_loss_dict, global_step=epoch)
    print('\nAverage training loss (epoch {}): {}'.format(epoch, train_loss_dict))

    val_losses = validate(model, validation_loader, device, criterion, loss_size)
    avg_val_losses = val_losses / len(validation_loader)
    val_loss_dict = dict(zip(_output_names + ['total'], avg_val_losses.tolist()))
    writer.add_scalars('validation_loss', val_loss_dict, global_step=epoch)
    print('\nAverage validation loss (epoch {}): {}'.format(epoch, val_loss_dict))

    total_val_loss = val_losses[-1]
    lr_modifier.step(total_val_loss)

    avg_v_loss.append(val_loss_dict['total'])
    avg_t_loss.append(train_loss_dict['total'])

    if val_loss_dict['total'] < best_val_loss:
        best_val_loss = val_loss_dict['total']
        counter = 0
        properties.update({'model_state_dict': model.state_dict()})
        properties.update({'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_loss_dict, 'val_loss': val_loss_dict, 'epoch': epoch})
        torch.save(properties, save_file + "_rs{}.pt".format(epoch))
    else:
        counter += 1
        if counter >= patience:
            early_stop = True
            break

    if early_stop:
        break
    
    torch.cuda.empty_cache()


properties.update({'model_state_dict': model.state_dict()})
properties.update({
    'optimizer_state_dict': optimizer.state_dict(),
    'train_loss': train_loss_dict,
    'val_loss': val_loss_dict,
    'epoch': epoch
})
torch.save(properties, save_file+".pt")

100%|██████████| 3/3 [00:18<00:00,  6.32s/it]



Average training loss (epoch 0): {'ca_dist': 3.7945549488067627, 'cb_dist': 3.77087140083313, 'no_dist': 3.566787004470825, 'omega': 3.743830919265747, 'theta': 3.485022783279419, 'phi': 3.548717737197876, 'total': 21.9097843170166}


100%|██████████| 1/1 [00:00<00:00,  1.24it/s]



Average validation loss (epoch 0): {'ca_dist': 1562516127744.0, 'cb_dist': 205191823360.0, 'no_dist': 1562351370240.0, 'omega': 539834941440.0, 'theta': 49077133312.0, 'phi': 175577612288.0, 'total': 4094548770816.0}


100%|██████████| 3/3 [00:14<00:00,  4.73s/it]



Average training loss (epoch 1): {'ca_dist': 3.253105878829956, 'cb_dist': 3.0995094776153564, 'no_dist': 3.317845106124878, 'omega': 3.0197160243988037, 'theta': 3.2216269969940186, 'phi': 3.2517940998077393, 'total': 19.163597106933594}


100%|██████████| 1/1 [00:00<00:00,  1.27it/s]



Average validation loss (epoch 1): {'ca_dist': 42.96299743652344, 'cb_dist': 44332.28515625, 'no_dist': 129278.9140625, 'omega': 491.0773010253906, 'theta': 10.39402961730957, 'phi': 22767.755859375, 'total': 196923.375}


100%|██████████| 3/3 [00:14<00:00,  4.75s/it]



Average training loss (epoch 2): {'ca_dist': 2.846576690673828, 'cb_dist': 2.7361552715301514, 'no_dist': 3.191948175430298, 'omega': 2.731806516647339, 'theta': 3.108985185623169, 'phi': 3.073443651199341, 'total': 17.688913345336914}


100%|██████████| 1/1 [00:00<00:00,  1.35it/s]



Average validation loss (epoch 2): {'ca_dist': 3.484076738357544, 'cb_dist': 9.450819969177246, 'no_dist': 528.4789428710938, 'omega': 19.92108154296875, 'theta': 337.87432861328125, 'phi': 181.99420166015625, 'total': 1081.203369140625}


100%|██████████| 3/3 [00:14<00:00,  4.72s/it]



Average training loss (epoch 3): {'ca_dist': 2.6886870861053467, 'cb_dist': 2.526691198348999, 'no_dist': 3.1403071880340576, 'omega': 2.5373518466949463, 'theta': 3.057852029800415, 'phi': 2.972414016723633, 'total': 16.923301696777344}


100%|██████████| 1/1 [00:00<00:00,  1.32it/s]



Average validation loss (epoch 3): {'ca_dist': 45.65190505981445, 'cb_dist': 78.73109436035156, 'no_dist': 13.042495727539062, 'omega': 79.58202362060547, 'theta': 21.55893325805664, 'phi': 3.418403387069702, 'total': 241.98483276367188}


100%|██████████| 3/3 [00:14<00:00,  4.75s/it]



Average training loss (epoch 4): {'ca_dist': 2.617968797683716, 'cb_dist': 2.431241989135742, 'no_dist': 3.0856196880340576, 'omega': 2.4484364986419678, 'theta': 3.015669584274292, 'phi': 2.9273059368133545, 'total': 16.526241302490234}


100%|██████████| 1/1 [00:00<00:00,  1.29it/s]



Average validation loss (epoch 4): {'ca_dist': 3.0745346546173096, 'cb_dist': 3.6844160556793213, 'no_dist': 3.5483922958374023, 'omega': 3.757876396179199, 'theta': 3.3905251026153564, 'phi': 3.4184699058532715, 'total': 20.87421417236328}


100%|██████████| 3/3 [00:14<00:00,  4.72s/it]



Average training loss (epoch 5): {'ca_dist': 2.5262811183929443, 'cb_dist': 2.3530848026275635, 'no_dist': 3.019594192504883, 'omega': 2.367924213409424, 'theta': 2.985360860824585, 'phi': 2.8815536499023438, 'total': 16.133798599243164}


100%|██████████| 1/1 [00:00<00:00,  1.57it/s]



Average validation loss (epoch 5): {'ca_dist': 1.9367501735687256, 'cb_dist': 1.9037513732910156, 'no_dist': 2.736081123352051, 'omega': 2.5462734699249268, 'theta': 2.6399238109588623, 'phi': 2.829958200454712, 'total': 14.592738151550293}


100%|██████████| 3/3 [00:14<00:00,  4.76s/it]



Average training loss (epoch 6): {'ca_dist': 2.417114496231079, 'cb_dist': 2.2918741703033447, 'no_dist': 2.9638259410858154, 'omega': 2.310480833053589, 'theta': 2.9521169662475586, 'phi': 2.8577516078948975, 'total': 15.793164253234863}


100%|██████████| 1/1 [00:00<00:00,  1.38it/s]



Average validation loss (epoch 6): {'ca_dist': 2.026782989501953, 'cb_dist': 2.2612438201904297, 'no_dist': 2.9551162719726562, 'omega': 2.257906198501587, 'theta': 2.649580478668213, 'phi': 2.671962261199951, 'total': 14.822591781616211}


100%|██████████| 3/3 [00:14<00:00,  4.73s/it]



Average training loss (epoch 7): {'ca_dist': 2.3290693759918213, 'cb_dist': 2.237582206726074, 'no_dist': 2.907475471496582, 'omega': 2.2402865886688232, 'theta': 2.9256317615509033, 'phi': 2.8253767490386963, 'total': 15.465423583984375}


100%|██████████| 1/1 [00:00<00:00,  1.31it/s]



Average validation loss (epoch 7): {'ca_dist': 1.8983323574066162, 'cb_dist': 2.382483959197998, 'no_dist': 3.082764148712158, 'omega': 2.3018598556518555, 'theta': 2.7709763050079346, 'phi': 2.6944429874420166, 'total': 15.130858421325684}


100%|██████████| 3/3 [00:14<00:00,  4.74s/it]



Average training loss (epoch 8): {'ca_dist': 2.233607053756714, 'cb_dist': 2.1843109130859375, 'no_dist': 2.8309030532836914, 'omega': 2.1793630123138428, 'theta': 2.8953120708465576, 'phi': 2.7857038974761963, 'total': 15.109200477600098}


100%|██████████| 1/1 [00:00<00:00,  1.31it/s]



Average validation loss (epoch 8): {'ca_dist': 1.9897805452346802, 'cb_dist': 2.2980785369873047, 'no_dist': 3.038214683532715, 'omega': 2.275294065475464, 'theta': 2.8349175453186035, 'phi': 2.7661499977111816, 'total': 15.202434539794922}


100%|██████████| 3/3 [00:14<00:00,  4.70s/it]



Average training loss (epoch 9): {'ca_dist': 2.1444685459136963, 'cb_dist': 2.125218629837036, 'no_dist': 2.7486705780029297, 'omega': 2.1202356815338135, 'theta': 2.8394834995269775, 'phi': 2.7482211589813232, 'total': 14.726298332214355}


100%|██████████| 1/1 [00:00<00:00,  1.29it/s]



Average validation loss (epoch 9): {'ca_dist': 1.954177737236023, 'cb_dist': 1.9627525806427002, 'no_dist': 2.7009477615356445, 'omega': 2.101489543914795, 'theta': 2.849574565887451, 'phi': 2.694178342819214, 'total': 14.263121604919434}


100%|██████████| 3/3 [00:14<00:00,  4.73s/it]



Average training loss (epoch 10): {'ca_dist': 2.0480000972747803, 'cb_dist': 2.0712201595306396, 'no_dist': 2.6645984649658203, 'omega': 2.06435227394104, 'theta': 2.7046291828155518, 'phi': 2.7125205993652344, 'total': 14.265320777893066}


100%|██████████| 1/1 [00:00<00:00,  1.38it/s]



Average validation loss (epoch 10): {'ca_dist': 1.9468618631362915, 'cb_dist': 1.8208839893341064, 'no_dist': 2.683018207550049, 'omega': 1.877790093421936, 'theta': 2.8401358127593994, 'phi': 2.622523069381714, 'total': 13.791213035583496}


100%|██████████| 3/3 [00:14<00:00,  4.71s/it]



Average training loss (epoch 11): {'ca_dist': 1.9589948654174805, 'cb_dist': 2.015178918838501, 'no_dist': 2.611527681350708, 'omega': 2.0063021183013916, 'theta': 2.622786283493042, 'phi': 2.670090913772583, 'total': 13.884881019592285}


100%|██████████| 1/1 [00:00<00:00,  1.39it/s]



Average validation loss (epoch 11): {'ca_dist': 1.8933639526367188, 'cb_dist': 1.8347642421722412, 'no_dist': 2.4590935707092285, 'omega': 1.8174734115600586, 'theta': 2.7759037017822266, 'phi': 2.642775535583496, 'total': 13.42337417602539}


100%|██████████| 3/3 [00:14<00:00,  4.75s/it]



Average training loss (epoch 12): {'ca_dist': 1.8945955038070679, 'cb_dist': 1.9523110389709473, 'no_dist': 2.567189931869507, 'omega': 1.9462966918945312, 'theta': 2.562464952468872, 'phi': 2.632776975631714, 'total': 13.555633544921875}


100%|██████████| 1/1 [00:00<00:00,  1.31it/s]



Average validation loss (epoch 12): {'ca_dist': 1.7737727165222168, 'cb_dist': 1.8403033018112183, 'no_dist': 2.2260255813598633, 'omega': 1.7420411109924316, 'theta': 2.6464107036590576, 'phi': 2.6315529346466064, 'total': 12.860106468200684}


100%|██████████| 3/3 [00:14<00:00,  4.69s/it]



Average training loss (epoch 13): {'ca_dist': 1.8317352533340454, 'cb_dist': 1.8903709650039673, 'no_dist': 2.5211479663848877, 'omega': 1.8875738382339478, 'theta': 2.519946813583374, 'phi': 2.581350326538086, 'total': 13.232124328613281}


100%|██████████| 1/1 [00:00<00:00,  1.54it/s]



Average validation loss (epoch 13): {'ca_dist': 1.6948823928833008, 'cb_dist': 1.859534502029419, 'no_dist': 2.176562547683716, 'omega': 1.7215999364852905, 'theta': 2.585479736328125, 'phi': 2.552844285964966, 'total': 12.590903282165527}


100%|██████████| 3/3 [00:12<00:00,  4.00s/it]



Average training loss (epoch 14): {'ca_dist': 1.7708858251571655, 'cb_dist': 1.8139653205871582, 'no_dist': 2.4742887020111084, 'omega': 1.8057283163070679, 'theta': 2.4723997116088867, 'phi': 2.4974496364593506, 'total': 12.834716796875}


100%|██████████| 1/1 [00:00<00:00,  1.58it/s]



Average validation loss (epoch 14): {'ca_dist': 1.6697908639907837, 'cb_dist': 1.7515448331832886, 'no_dist': 2.154567003250122, 'omega': 1.6210558414459229, 'theta': 2.5281622409820557, 'phi': 2.3524632453918457, 'total': 12.077583312988281}


100%|██████████| 3/3 [00:13<00:00,  4.44s/it]



Average training loss (epoch 15): {'ca_dist': 1.7196308374404907, 'cb_dist': 1.7323232889175415, 'no_dist': 2.434652090072632, 'omega': 1.7122770547866821, 'theta': 2.4288556575775146, 'phi': 2.4329135417938232, 'total': 12.460652351379395}


100%|██████████| 1/1 [00:00<00:00,  1.49it/s]



Average validation loss (epoch 15): {'ca_dist': 1.64207923412323, 'cb_dist': 1.5787755250930786, 'no_dist': 2.128288984298706, 'omega': 1.6058752536773682, 'theta': 2.4050285816192627, 'phi': 2.2378628253936768, 'total': 11.597909927368164}


100%|██████████| 3/3 [00:13<00:00,  4.51s/it]



Average training loss (epoch 16): {'ca_dist': 1.6698633432388306, 'cb_dist': 1.66587495803833, 'no_dist': 2.395256280899048, 'omega': 1.64657461643219, 'theta': 2.389798402786255, 'phi': 2.3939743041992188, 'total': 12.16134262084961}


100%|██████████| 1/1 [00:00<00:00,  2.04it/s]



Average validation loss (epoch 16): {'ca_dist': 1.7034519910812378, 'cb_dist': 1.5996814966201782, 'no_dist': 2.281456232070923, 'omega': 1.7527133226394653, 'theta': 2.4574460983276367, 'phi': 2.3046481609344482, 'total': 12.099397659301758}


100%|██████████| 3/3 [00:10<00:00,  3.38s/it]



Average training loss (epoch 17): {'ca_dist': 1.6139761209487915, 'cb_dist': 1.6095091104507446, 'no_dist': 2.3550503253936768, 'omega': 1.5941678285598755, 'theta': 2.3458118438720703, 'phi': 2.3515217304229736, 'total': 11.870037078857422}


100%|██████████| 1/1 [00:00<00:00,  2.37it/s]



Average validation loss (epoch 17): {'ca_dist': 1.6147269010543823, 'cb_dist': 1.5419511795043945, 'no_dist': 2.3330328464508057, 'omega': 1.6964266300201416, 'theta': 2.375319242477417, 'phi': 2.236090660095215, 'total': 11.797547340393066}


100%|██████████| 3/3 [00:09<00:00,  3.25s/it]



Average training loss (epoch 18): {'ca_dist': 1.5680227279663086, 'cb_dist': 1.564890742301941, 'no_dist': 2.3102872371673584, 'omega': 1.5469290018081665, 'theta': 2.2992465496063232, 'phi': 2.3057143688201904, 'total': 11.59508991241455}


100%|██████████| 1/1 [00:00<00:00,  1.35it/s]



Average validation loss (epoch 18): {'ca_dist': 1.4796663522720337, 'cb_dist': 1.4690155982971191, 'no_dist': 2.250156879425049, 'omega': 1.554129958152771, 'theta': 2.271577835083008, 'phi': 2.1413424015045166, 'total': 11.165888786315918}


100%|██████████| 3/3 [00:14<00:00,  4.69s/it]



Average training loss (epoch 19): {'ca_dist': 1.5152873992919922, 'cb_dist': 1.514769196510315, 'no_dist': 2.2634975910186768, 'omega': 1.495141625404358, 'theta': 2.2542293071746826, 'phi': 2.2622509002685547, 'total': 11.305176734924316}


100%|██████████| 1/1 [00:00<00:00,  1.53it/s]



Average validation loss (epoch 19): {'ca_dist': 1.4195315837860107, 'cb_dist': 1.4240238666534424, 'no_dist': 2.1736338138580322, 'omega': 1.4503378868103027, 'theta': 2.2680890560150146, 'phi': 2.1665713787078857, 'total': 10.90218734741211}


100%|██████████| 3/3 [00:11<00:00,  3.74s/it]



Average training loss (epoch 20): {'ca_dist': 1.4652806520462036, 'cb_dist': 1.4680681228637695, 'no_dist': 2.228365421295166, 'omega': 1.4461647272109985, 'theta': 2.211198568344116, 'phi': 2.2210004329681396, 'total': 11.040078163146973}


100%|██████████| 1/1 [00:00<00:00,  2.04it/s]



Average validation loss (epoch 20): {'ca_dist': 1.3456084728240967, 'cb_dist': 1.413948893547058, 'no_dist': 2.2504379749298096, 'omega': 1.380160927772522, 'theta': 2.234394073486328, 'phi': 2.1801810264587402, 'total': 10.804731369018555}


100%|██████████| 3/3 [00:07<00:00,  2.53s/it]



Average training loss (epoch 21): {'ca_dist': 1.4138332605361938, 'cb_dist': 1.421836256980896, 'no_dist': 2.186722755432129, 'omega': 1.397927165031433, 'theta': 2.166116714477539, 'phi': 2.1788299083709717, 'total': 10.765266418457031}


100%|██████████| 1/1 [00:00<00:00,  1.72it/s]



Average validation loss (epoch 21): {'ca_dist': 1.268751621246338, 'cb_dist': 1.4464361667633057, 'no_dist': 2.146299123764038, 'omega': 1.326584815979004, 'theta': 2.199134111404419, 'phi': 2.173083782196045, 'total': 10.56028938293457}


100%|██████████| 3/3 [00:09<00:00,  3.02s/it]



Average training loss (epoch 22): {'ca_dist': 1.3738843202590942, 'cb_dist': 1.374110221862793, 'no_dist': 2.1461665630340576, 'omega': 1.3516124486923218, 'theta': 2.125725030899048, 'phi': 2.136183738708496, 'total': 10.507682800292969}


100%|██████████| 1/1 [00:00<00:00,  1.57it/s]



Average validation loss (epoch 22): {'ca_dist': 1.212906002998352, 'cb_dist': 1.5048997402191162, 'no_dist': 2.0133755207061768, 'omega': 1.2997174263000488, 'theta': 2.173523426055908, 'phi': 2.1309609413146973, 'total': 10.335382461547852}


100%|██████████| 3/3 [00:12<00:00,  4.30s/it]



Average training loss (epoch 23): {'ca_dist': 1.331377387046814, 'cb_dist': 1.3399955034255981, 'no_dist': 2.1060917377471924, 'omega': 1.31578528881073, 'theta': 2.082857370376587, 'phi': 2.097623825073242, 'total': 10.273731231689453}


100%|██████████| 1/1 [00:00<00:00,  3.14it/s]



Average validation loss (epoch 23): {'ca_dist': 1.1969999074935913, 'cb_dist': 1.5740227699279785, 'no_dist': 1.9181768894195557, 'omega': 1.2898437976837158, 'theta': 2.1420083045959473, 'phi': 2.0923385620117188, 'total': 10.213390350341797}


100%|██████████| 3/3 [00:07<00:00,  2.66s/it]



Average training loss (epoch 24): {'ca_dist': 1.2994036674499512, 'cb_dist': 1.3041633367538452, 'no_dist': 2.0690016746520996, 'omega': 1.2811778783798218, 'theta': 2.045006275177002, 'phi': 2.0588009357452393, 'total': 10.057554244995117}


100%|██████████| 1/1 [00:00<00:00,  1.13it/s]



Average validation loss (epoch 24): {'ca_dist': 1.1692897081375122, 'cb_dist': 1.573205828666687, 'no_dist': 1.8018957376480103, 'omega': 1.2140069007873535, 'theta': 2.0696465969085693, 'phi': 1.9836407899856567, 'total': 9.811685562133789}


  0%|          | 0/3 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [None]:

#t=0.42 v=0.56
import matplotlib.pyplot as plt
plt.plot(use_loss)
plt.ylabel('Loss $L(y,\hat{y};a)$')
plt.xlabel('Epochs')
print(use_loss)

In [None]:
import matplotlib.pyplot as plt
plt.plot(avg_t_loss)
print(type(avg_t_loss))
plt.ylabel('Loss $L(y,\hat{y};a)$')
print(avg_t_loss)
plt.xlabel('Epochs')
# print(avg_t_loss)

In [None]:
import matplotlib.pyplot as plt
plt.plot(avg_v_loss)
plt.ylabel('Loss $L(y,\hat{y};a)$')
plt.xlabel('Epochs')
print(avg_v_loss)

In [None]:
plt.figure(figsize=(10,5))
plt.title("Training and Validation Loss")
plt.plot(avg_v_loss,label="val")
plt.plot(avg_t_loss,label="train")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()