In [1]:
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from glob import glob
import pickle
import nltk
from torch.nn.utils.rnn import pack_padded_sequence
import timm
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.preprocessing import MinMaxScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils.rnn import pad_packed_sequence
from efficientnet_pytorch_3d import EfficientNet3D

nltk.download('punkt')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('pun`kt_tab')

device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')

[nltk_data] Downloading package punkt to /home/gil/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /home/gil/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/gil/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Error loading pun`kt_tab: Package 'pun`kt_tab' not found
[nltk_data]     in index


In [31]:
params={'image_size':512,
        'lr':1e-4,
        'beta1':0.5,
        'beta2':0.999,
        'batch_size':5,
        'epochs':300,
        'image_count':50,
        'data_path':'../data/NIA7/',
        'train_csv':'fold_5_train.csv',
        'val_csv':'fold_5_val.csv',
        'vocab_path':'../data/vocab.pkl',
        'embed_size':300,
        'hidden_size':256,
        'num_layers':1,}

In [3]:
class CustomDataset(Dataset):
    def __init__(self, data_path, image_count, image_size, csv, class_dataset, vocab, transform=None):
        self.root = data_path
        self.image_count = image_count
        self.df = pd.read_csv(data_path + csv)
        self.class_dataset = class_dataset
        self.vocab = vocab
        self.transform = transform
        self.image_size = image_size

        # 'age', 'height', 'weight' 컬럼을 MinMaxScaler로 스케일링
        self.df['sex'] = self.df['sex'].map({'M': 0, 'F': 1})
        self.scaler = MinMaxScaler()
        self.df[['age', 'sex', 'height', 'weight']] = self.scaler.fit_transform(self.df[['age', 'sex', 'height', 'weight']])
        
        # 모든 데이터를 메모리에 캐싱
        self.cached_data = []
        for index, row in tqdm(self.df.iterrows(), total=len(self.df), desc="Caching data"):
            patient_id = row['ID']
            
            # CT 이미지(3D 데이터) 로드 및 캐싱
            ct_image_paths = glob(self.root + f'{patient_id}/02. CT/*.png')
            ct_images = torch.zeros(self.image_count, 1, self.image_size, self.image_size)
            ct_image_index = torch.randint(low=0, high=len(ct_image_paths), size=(self.image_count,))
            for count, idx in enumerate(ct_image_index):
                ct_image = Image.open(ct_image_paths[idx]).convert('L')
                if self.transform is not None:
                    ct_image = self.transform(ct_image)
                ct_images[count] = ct_image

            ct_images = ct_images.permute(1, 0, 2, 3)  # 형태: (1, image_count, image_size, image_size)

            # CR 이미지(2D 데이터) 로드 및 캐싱
            cr_image_path = glob(self.root + f'{patient_id}/01. CR/*.png')[0]
            cr_image = Image.open(cr_image_path).convert('L')
            if self.transform is not None:
                cr_image = self.transform(cr_image)

            # 캡션을 토큰 ID로 변환하여 캐싱
            caption = row['text']
            tokens = nltk.tokenize.word_tokenize(str(caption).lower())
            caption_tokens = [vocab('<start>')] + [vocab(token) for token in tokens] + [vocab('<end>')]
            target = torch.Tensor(caption_tokens)

            # 'age', 'height', 'weight' 값 캐싱
            sex = torch.tensor(row['sex'], dtype=torch.float32)
            age = torch.tensor(row['age'], dtype=torch.float32)
            height = torch.tensor(row['height'], dtype=torch.float32)
            weight = torch.tensor(row['weight'], dtype=torch.float32)

            # 캐싱된 데이터 저장
            self.cached_data.append((ct_images, cr_image, target, sex, age, height, weight))

    def __getitem__(self, index):
        # 캐싱된 데이터에서 인덱스로 직접 접근
        return self.cached_data[index]

    def __len__(self):
        return len(self.df)

class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

def collate_fn(data):
    # Sort data by caption length (descending order).
    data.sort(key=lambda x: len(x[2]), reverse=True)
    ct_images, cr_images, captions, sexs, ages, heights, weights = zip(*data)

    # Merge images
    ct_images = torch.stack(ct_images, 0)
    cr_images = torch.stack(cr_images, 0)

    # Merge captions
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]

    # Stack ages and heights
    sexs = torch.stack(sexs, 0)
    ages = torch.stack(ages, 0)
    heights = torch.stack(heights, 0)
    weights = torch.stack(weights, 0)

    return ct_images, cr_images, targets, lengths, sexs, ages, heights, weights

def idx2word(vocab, indices):
    sentence = []
    for i in range(params['batch_size']):
        indices[i].cpu().numpy()
    
    for index in indices:
        word = vocab.idx2word[index]
        sentence.append(word)
    return sentence

In [23]:
# ======= Encoder - 2D Feature Extractor (CNN) =======
class FeatureExtractor2D(nn.Module):
    """2D Feature extractor block for X-ray images (simple CNN version)"""
    def __init__(self, dropout_rate=0.3):
        super(FeatureExtractor2D, self).__init__()
        self.conv2d_1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2d_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2d_3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # 2x downsampling
        self.pool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(dropout_rate)  # Dropout layer 추가

    def forward(self, inputs):
        x = self.conv2d_1(inputs)
        x = F.relu(x)
        x = self.pool(x)
        x = self.dropout(x)  # Apply dropout

        x = self.conv2d_2(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.dropout(x)  # Apply dropout

        x = self.conv2d_3(x)
        x = F.relu(x)
        x = self.pool(x)
        features = self.flatten(x)
        features = self.dropout(features)  # Apply dropout at the end
        return features
    

# ======= Encoder - 3D Feature Extractor =======
class FeatureExtractor3D(nn.Module):
    """3D Feature extractor block for CT images"""
    def __init__(self, dropout_rate=0.3):
        super(FeatureExtractor3D, self).__init__()
        self.conv3d_1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.conv3d_2 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.conv3d_3 = nn.Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
        self.pool = nn.MaxPool3d(2)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(dropout_rate)  # Dropout layer 추가

    def forward(self, inputs):
        x = self.conv3d_1(inputs)
        x = F.relu(x)
        x = self.pool(x)
        x = self.dropout(x)  # Apply dropout

        x = self.conv3d_2(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.dropout(x)  # Apply dropout

        x = self.conv3d_3(x)
        x = F.relu(x)
        x = self.pool(x)
        features = self.flatten(x)  # Flatten the output to a 2D tensor
        features = self.dropout(features)  # Apply dropout at the end
        return features


In [24]:
class AttentionMILModel(nn.Module):
    def __init__(self, image_feature_dim_2d, image_feature_dim_3d, 
                 feature_extractor_2d: nn.Module, feature_extractor_3d: nn.Module,
                 age_height_weight_input_size=4, dropout_rate=0.3):
        super(AttentionMILModel, self).__init__()
        self.feature_extractor_2d = feature_extractor_2d
        self.feature_extractor_3d = feature_extractor_3d

        self.age_height_weight_mlp = nn.Sequential(
            nn.Linear(age_height_weight_input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.combined_feature_dim = image_feature_dim_2d + image_feature_dim_3d + 32
        self.dropout = nn.Dropout(dropout_rate)

        # Attention + projection for T5
        self.attention_2d = nn.Sequential(
            nn.Linear(self.combined_feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        self.feature_proj = nn.Linear(self.combined_feature_dim, 512)  # ✅ T5 input projection

    def forward(self, inputs_2d, inputs_3d, age_height_weight):
        features_2d = self.feature_extractor_2d(inputs_2d)
        features_3d = self.feature_extractor_3d(inputs_3d)
        features_meta = self.age_height_weight_mlp(age_height_weight)

        features = torch.cat([features_2d, features_3d, features_meta], dim=1)

        attention_weights = self.attention_2d(features)
        attention_weights = F.softmax(attention_weights, dim=1)
        attention_weights = self.dropout(attention_weights)

        attended = features * attention_weights
        attended = self.dropout(attended)

        projected = self.feature_proj(attended)  # [B, 512]
        return projected  # ❗ T5 decoder에 넘기기 전에는 .unsqueeze(1) 필요


In [25]:
from transformers import T5Model, T5Config
import torch.nn as nn
import torch.nn.functional as F
import torch

class SwiGLUFFN(nn.Module):
    def __init__(self, hidden_size, intermediate_size, dropout_rate=0.1):
        super(SwiGLUFFN, self).__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        x = gate * up
        x = self.down_proj(x)
        x = self.dropout(x)
        return x

class ReducedDecoderT5SwiGLU(nn.Module):
    def __init__(self, vocab_size, max_seq_length=50, dropout_rate=0.7):
        super(ReducedDecoderT5SwiGLU, self).__init__()
        self.max_seq_length = max_seq_length

        config = T5Config(
            vocab_size=vocab_size,
            d_model=512,
            d_ff=2048,
            num_layers=6,
            num_heads=8,
            dropout_rate=dropout_rate,
            is_encoder_decoder=True,
            output_attentions=True,
            return_dict=True
        )

        self.transformer = T5Model(config)
        self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)

        # Replace FFN with SwiGLUFFN
        for block in self.transformer.decoder.block:
            hidden_size = config.d_model
            intermediate_size = config.d_ff
            block.layer[1].DenseReluDense = SwiGLUFFN(hidden_size, intermediate_size, dropout_rate)

    def forward(self, encoder_hidden_states, decoder_input_ids, decoder_attention_mask=None):
        decoder_outputs = self.transformer(
            encoder_outputs=(encoder_hidden_states,),  # <-- 핵심
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask
        )
        sequence_output = decoder_outputs.last_hidden_state
        logits = self.lm_head(sequence_output)
        return logits


In [26]:
with open(params['vocab_path'], 'rb') as f:
        vocab = pickle.load(f)

transform = transforms.Compose([ 
    transforms.ToTensor(), 
    transforms.Normalize((0.5,), (0.5,))  # Adjust for grayscale images with 1 channel
])

In [27]:
Feature_Extractor2D = FeatureExtractor2D()
Feature_Extractor3D = FeatureExtractor3D()

# Initialize encoder and new decoder
encoder = AttentionMILModel(65536, 196608, Feature_Extractor2D, Feature_Extractor3D).to(device)
decoder = ReducedDecoderT5SwiGLU(len(vocab)).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
model_param = list(decoder.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(model_param, lr=params['lr'], betas=(params['beta1'], params['beta2']), weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)

In [28]:
# EfficientNet B3의 FeatureExtractor2D 출력 크기 확인
sample_input = torch.randn(1, 1, params['image_size'], params['image_size']).to(device)
features_2d = Feature_Extractor2D(sample_input)
print("FeatureExtractor2D feature size:", features_2d.shape)

FeatureExtractor2D feature size: torch.Size([1, 65536])


In [29]:
# FeatureExtractor3D의 출력 크기 확인
sample_ct_input = torch.randn(1, 1, params['image_count'], params['image_size'], params['image_size']).to(device)
features_3d = Feature_Extractor3D(sample_ct_input)
print("FeatureExtractor3D output size:", features_3d.shape)

FeatureExtractor3D output size: torch.Size([1, 196608])


# caching data

In [13]:
train_dataset=CustomDataset(params['data_path'],params['image_count'],params['image_size'],params['train_csv'],'train',vocab,transform=transform)
test_dataset=CustomDataset(params['data_path'],params['image_count'],params['image_size'],params['val_csv'],'val',vocab,transform=transform)
train_dataloader=DataLoader(train_dataset,batch_size=params['batch_size'],shuffle=True,collate_fn=collate_fn)
val_dataloader=DataLoader(test_dataset,batch_size=params['batch_size'],shuffle=True,collate_fn=collate_fn)

Caching data: 100%|██████████| 649/649 [33:17<00:00,  3.08s/it] 
Caching data: 100%|██████████| 163/163 [05:04<00:00,  1.87s/it]


In [None]:
best_val_loss = float("inf")

for epoch in range(params['epochs']):
    # Training loop
    train = tqdm(train_dataloader)
    count = 0
    train_loss = 0.0
    decoder.train()  # Ensure the decoder is in training mode
    encoder.train()
    
    for ct_images, cr_images, captions, lengths, sexs, ages, heights, weights in train:
        ct_images, cr_images, captions = ct_images.to(device), cr_images.to(device), captions.to(device)
        sexs, ages, heights, weights = sexs.to(device), ages.to(device), heights.to(device), weights.to(device)
        count += 1

        # Forward pass for encoder
        features = encoder(cr_images, ct_images, torch.stack([sexs, ages, heights, weights], dim=1))
        features = features.unsqueeze(1)  # [B, D] → [B, 1, D]
        
        # Prepare inputs for T5 decoder
        labels = captions[:, 1:]          # Remove the first token for labels

        # Forward pass through decoder with T5 format
        outputs = decoder(encoder_hidden_states=features, decoder_input_ids=labels)
        loss = criterion(outputs.view(-1, outputs.size(-1)), labels.contiguous().view(-1))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()

        train.set_description(f"epoch: {epoch+1}/{params['epochs']} Step: {count+1} loss : {train_loss/count:.4f} ")

    # Validation loop
    with torch.no_grad():
        val_count = 0
        val_loss = 0.0
        val = tqdm(val_dataloader)
        decoder.eval()  # Ensure the decoder is in evaluation mode
        encoder.eval()
        
        for ct_images, cr_images, captions, lengths, sexs, ages, heights, weights in val:
            val_count += 1
            ct_images, cr_images, captions = ct_images.to(device), cr_images.to(device), captions.to(device)
            sexs, ages, heights, weights = sexs.to(device), ages.to(device), heights.to(device), weights.to(device)

            # Forward pass for validation
            features = encoder(cr_images, ct_images, torch.stack([sexs, ages, heights, weights], dim=1))
            features = features.unsqueeze(1)  # [B, D] → [B, 1, D]
            
            # Validation inputs
            labels = captions[:, 1:]

            # Validation forward pass through decoder
            outputs = decoder(encoder_hidden_states=features, decoder_input_ids=labels)
            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.contiguous().view(-1))
            val_loss += loss.item()
            val.set_description(f"epoch: {epoch+1}/{params['epochs']} Step: {val_count+1} loss : {val_loss/val_count:.4f} ")
        
        # Average validation loss
        avg_val_loss = val_loss / len(val_dataloader)
        print(f"Epoch [{epoch+1}/{params['epochs']}], Validation Loss: {avg_val_loss:.4f}")
        print("Current learning rate:", optimizer.param_groups[0]['lr'])

        # Call scheduler with validation loss
        scheduler.step(avg_val_loss)

        # Save the model if it has the best validation loss so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(encoder.state_dict(), '../model/test_att_encoder.pth')
            torch.save(decoder.state_dict(), '../model/test_att_decoder.pth')
            print("Best model saved with Validation Loss:", best_val_loss)


  0%|          | 0/130 [00:00<?, ?it/s]

epoch: 1/300 Step: 131 loss : 0.3850 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 1/300 Step: 34 loss : 0.1400 : 100%|██████████| 33/33 [00:06<00:00,  4.86it/s]


Epoch [1/300], Validation Loss: 0.1400
Current learning rate: 0.0001
Best model saved with Validation Loss: 4.620517782866955


epoch: 2/300 Step: 131 loss : 0.3740 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 2/300 Step: 34 loss : 0.1332 : 100%|██████████| 33/33 [00:06<00:00,  4.91it/s]


Epoch [2/300], Validation Loss: 0.1332
Current learning rate: 0.0001
Best model saved with Validation Loss: 4.394742090255022


epoch: 3/300 Step: 131 loss : 0.3620 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 3/300 Step: 34 loss : 0.1386 : 100%|██████████| 33/33 [00:06<00:00,  4.92it/s]


Epoch [3/300], Validation Loss: 0.1386
Current learning rate: 0.0001


epoch: 4/300 Step: 131 loss : 0.3520 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 4/300 Step: 34 loss : 0.1273 : 100%|██████████| 33/33 [00:06<00:00,  4.99it/s]


Epoch [4/300], Validation Loss: 0.1273
Current learning rate: 0.0001
Best model saved with Validation Loss: 4.20126311853528


epoch: 5/300 Step: 131 loss : 0.3392 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 5/300 Step: 34 loss : 0.1177 : 100%|██████████| 33/33 [00:06<00:00,  4.85it/s]


Epoch [5/300], Validation Loss: 0.1177
Current learning rate: 0.0001
Best model saved with Validation Loss: 3.8852624744176865


epoch: 6/300 Step: 131 loss : 0.3930 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 6/300 Step: 34 loss : 0.1112 : 100%|██████████| 33/33 [00:06<00:00,  4.83it/s]


Epoch [6/300], Validation Loss: 0.1112
Current learning rate: 0.0001
Best model saved with Validation Loss: 3.669293670915067


epoch: 7/300 Step: 131 loss : 0.3144 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 7/300 Step: 34 loss : 0.1037 : 100%|██████████| 33/33 [00:06<00:00,  4.91it/s]


Epoch [7/300], Validation Loss: 0.1037
Current learning rate: 0.0001
Best model saved with Validation Loss: 3.423365879803896


epoch: 8/300 Step: 131 loss : 0.3051 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 8/300 Step: 34 loss : 0.1095 : 100%|██████████| 33/33 [00:06<00:00,  4.91it/s]


Epoch [8/300], Validation Loss: 0.1095
Current learning rate: 0.0001


epoch: 9/300 Step: 131 loss : 0.2935 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 9/300 Step: 34 loss : 0.1040 : 100%|██████████| 33/33 [00:06<00:00,  5.00it/s]


Epoch [9/300], Validation Loss: 0.1040
Current learning rate: 0.0001


epoch: 10/300 Step: 131 loss : 0.2897 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 10/300 Step: 34 loss : 0.1038 : 100%|██████████| 33/33 [00:06<00:00,  4.97it/s]


Epoch [10/300], Validation Loss: 0.1038
Current learning rate: 5e-05


epoch: 11/300 Step: 131 loss : 0.2852 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 11/300 Step: 34 loss : 0.1037 : 100%|██████████| 33/33 [00:06<00:00,  4.77it/s]


Epoch [11/300], Validation Loss: 0.1037
Current learning rate: 5e-05
Best model saved with Validation Loss: 3.4229789823293686


epoch: 12/300 Step: 131 loss : 0.2756 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 12/300 Step: 34 loss : 0.0989 : 100%|██████████| 33/33 [00:06<00:00,  4.85it/s]


Epoch [12/300], Validation Loss: 0.0989
Current learning rate: 5e-05
Best model saved with Validation Loss: 3.2633445411920547


epoch: 13/300 Step: 131 loss : 0.2745 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 13/300 Step: 34 loss : 0.1028 : 100%|██████████| 33/33 [00:06<00:00,  4.86it/s]


Epoch [13/300], Validation Loss: 0.1028
Current learning rate: 5e-05


epoch: 14/300 Step: 131 loss : 0.2659 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 14/300 Step: 34 loss : 0.0948 : 100%|██████████| 33/33 [00:06<00:00,  4.79it/s]


Epoch [14/300], Validation Loss: 0.0948
Current learning rate: 5e-05
Best model saved with Validation Loss: 3.1277591474354267


epoch: 15/300 Step: 131 loss : 0.2589 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 15/300 Step: 34 loss : 0.0932 : 100%|██████████| 33/33 [00:06<00:00,  4.99it/s]


Epoch [15/300], Validation Loss: 0.0932
Current learning rate: 5e-05
Best model saved with Validation Loss: 3.074034094810486


epoch: 16/300 Step: 131 loss : 0.2561 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 16/300 Step: 34 loss : 0.0876 : 100%|██████████| 33/33 [00:06<00:00,  4.87it/s]


Epoch [16/300], Validation Loss: 0.0876
Current learning rate: 5e-05
Best model saved with Validation Loss: 2.892386980354786


epoch: 17/300 Step: 131 loss : 0.2557 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 17/300 Step: 34 loss : 0.0878 : 100%|██████████| 33/33 [00:06<00:00,  4.82it/s]


Epoch [17/300], Validation Loss: 0.0878
Current learning rate: 5e-05


epoch: 18/300 Step: 131 loss : 0.2514 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 18/300 Step: 34 loss : 0.0905 : 100%|██████████| 33/33 [00:06<00:00,  4.91it/s]


Epoch [18/300], Validation Loss: 0.0905
Current learning rate: 5e-05


epoch: 19/300 Step: 131 loss : 0.2424 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 19/300 Step: 34 loss : 0.0903 : 100%|██████████| 33/33 [00:06<00:00,  4.89it/s]


Epoch [19/300], Validation Loss: 0.0903
Current learning rate: 2.5e-05


epoch: 20/300 Step: 131 loss : 0.2449 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 20/300 Step: 34 loss : 0.0897 : 100%|██████████| 33/33 [00:06<00:00,  4.91it/s]


Epoch [20/300], Validation Loss: 0.0897
Current learning rate: 2.5e-05


epoch: 21/300 Step: 131 loss : 0.2394 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 21/300 Step: 34 loss : 0.0834 : 100%|██████████| 33/33 [00:06<00:00,  4.90it/s]


Epoch [21/300], Validation Loss: 0.0834
Current learning rate: 1.25e-05
Best model saved with Validation Loss: 2.7535200230777264


epoch: 22/300 Step: 131 loss : 0.2361 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 22/300 Step: 34 loss : 0.0854 : 100%|██████████| 33/33 [00:06<00:00,  4.87it/s]


Epoch [22/300], Validation Loss: 0.0854
Current learning rate: 1.25e-05


epoch: 23/300 Step: 131 loss : 0.2376 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 23/300 Step: 34 loss : 0.0843 : 100%|██████████| 33/33 [00:06<00:00,  5.26it/s]


Epoch [23/300], Validation Loss: 0.0843
Current learning rate: 1.25e-05


epoch: 24/300 Step: 131 loss : 0.2358 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 24/300 Step: 34 loss : 0.0862 : 100%|██████████| 33/33 [00:06<00:00,  4.94it/s]


Epoch [24/300], Validation Loss: 0.0862
Current learning rate: 6.25e-06


epoch: 25/300 Step: 131 loss : 0.2369 : 100%|██████████| 130/130 [01:39<00:00,  1.30it/s]
epoch: 25/300 Step: 34 loss : 0.0841 : 100%|██████████| 33/33 [00:06<00:00,  5.19it/s]


Epoch [25/300], Validation Loss: 0.0841
Current learning rate: 6.25e-06


epoch: 26/300 Step: 131 loss : 0.2339 : 100%|██████████| 130/130 [01:39<00:00,  1.30it/s]
epoch: 26/300 Step: 34 loss : 0.0865 : 100%|██████████| 33/33 [00:06<00:00,  5.02it/s]


Epoch [26/300], Validation Loss: 0.0865
Current learning rate: 3.125e-06


epoch: 27/300 Step: 131 loss : 0.2406 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 27/300 Step: 34 loss : 0.0842 : 100%|██████████| 33/33 [00:06<00:00,  4.93it/s]


Epoch [27/300], Validation Loss: 0.0842
Current learning rate: 3.125e-06


epoch: 28/300 Step: 131 loss : 0.2366 : 100%|██████████| 130/130 [01:41<00:00,  1.28it/s]
epoch: 28/300 Step: 34 loss : 0.0866 : 100%|██████████| 33/33 [00:06<00:00,  4.92it/s]


Epoch [28/300], Validation Loss: 0.0866
Current learning rate: 1.5625e-06


epoch: 29/300 Step: 131 loss : 0.2362 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 29/300 Step: 34 loss : 0.0836 : 100%|██████████| 33/33 [00:06<00:00,  4.92it/s]


Epoch [29/300], Validation Loss: 0.0836
Current learning rate: 1.5625e-06


epoch: 30/300 Step: 131 loss : 0.2345 : 100%|██████████| 130/130 [01:39<00:00,  1.30it/s]
epoch: 30/300 Step: 34 loss : 0.0837 : 100%|██████████| 33/33 [00:06<00:00,  4.72it/s]


Epoch [30/300], Validation Loss: 0.0837
Current learning rate: 7.8125e-07


epoch: 31/300 Step: 131 loss : 0.2330 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 31/300 Step: 34 loss : 0.0810 : 100%|██████████| 33/33 [00:06<00:00,  5.04it/s]


Epoch [31/300], Validation Loss: 0.0810
Current learning rate: 7.8125e-07
Best model saved with Validation Loss: 2.673199510201812


epoch: 32/300 Step: 131 loss : 0.2352 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 32/300 Step: 34 loss : 0.0860 : 100%|██████████| 33/33 [00:06<00:00,  5.14it/s]


Epoch [32/300], Validation Loss: 0.0860
Current learning rate: 7.8125e-07


epoch: 33/300 Step: 131 loss : 0.2327 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 33/300 Step: 34 loss : 0.0872 : 100%|██████████| 33/33 [00:06<00:00,  5.23it/s]


Epoch [33/300], Validation Loss: 0.0872
Current learning rate: 7.8125e-07


epoch: 34/300 Step: 131 loss : 0.2347 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 34/300 Step: 34 loss : 0.0814 : 100%|██████████| 33/33 [00:06<00:00,  5.32it/s]


Epoch [34/300], Validation Loss: 0.0814
Current learning rate: 3.90625e-07


epoch: 35/300 Step: 131 loss : 0.2336 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 35/300 Step: 34 loss : 0.0821 : 100%|██████████| 33/33 [00:06<00:00,  5.17it/s]


Epoch [35/300], Validation Loss: 0.0821
Current learning rate: 3.90625e-07


epoch: 36/300 Step: 131 loss : 0.2307 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 36/300 Step: 34 loss : 0.0823 : 100%|██████████| 33/33 [00:07<00:00,  4.69it/s]


Epoch [36/300], Validation Loss: 0.0823
Current learning rate: 1.953125e-07


epoch: 37/300 Step: 131 loss : 0.2311 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 37/300 Step: 34 loss : 0.0849 : 100%|██████████| 33/33 [00:06<00:00,  5.06it/s]


Epoch [37/300], Validation Loss: 0.0849
Current learning rate: 1.953125e-07


epoch: 38/300 Step: 131 loss : 0.2349 : 100%|██████████| 130/130 [01:39<00:00,  1.30it/s]
epoch: 38/300 Step: 34 loss : 0.0822 : 100%|██████████| 33/33 [00:06<00:00,  4.87it/s]


Epoch [38/300], Validation Loss: 0.0822
Current learning rate: 9.765625e-08


epoch: 39/300 Step: 131 loss : 0.2355 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 39/300 Step: 34 loss : 0.0839 : 100%|██████████| 33/33 [00:06<00:00,  4.92it/s]


Epoch [39/300], Validation Loss: 0.0839
Current learning rate: 9.765625e-08


epoch: 40/300 Step: 131 loss : 0.2322 : 100%|██████████| 130/130 [01:38<00:00,  1.31it/s]
epoch: 40/300 Step: 34 loss : 0.0833 : 100%|██████████| 33/33 [00:06<00:00,  5.17it/s]


Epoch [40/300], Validation Loss: 0.0833
Current learning rate: 4.8828125e-08


epoch: 41/300 Step: 131 loss : 0.2330 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 41/300 Step: 34 loss : 0.0841 : 100%|██████████| 33/33 [00:06<00:00,  5.12it/s]


Epoch [41/300], Validation Loss: 0.0841
Current learning rate: 4.8828125e-08


epoch: 42/300 Step: 131 loss : 0.2350 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 42/300 Step: 34 loss : 0.0855 : 100%|██████████| 33/33 [00:06<00:00,  5.27it/s]


Epoch [42/300], Validation Loss: 0.0855
Current learning rate: 2.44140625e-08


epoch: 43/300 Step: 131 loss : 0.2385 : 100%|██████████| 130/130 [01:38<00:00,  1.32it/s]
epoch: 43/300 Step: 34 loss : 0.0840 : 100%|██████████| 33/33 [00:06<00:00,  5.04it/s]


Epoch [43/300], Validation Loss: 0.0840
Current learning rate: 2.44140625e-08


epoch: 44/300 Step: 131 loss : 0.2335 : 100%|██████████| 130/130 [01:39<00:00,  1.30it/s]
epoch: 44/300 Step: 34 loss : 0.0844 : 100%|██████████| 33/33 [00:07<00:00,  4.53it/s]


Epoch [44/300], Validation Loss: 0.0844
Current learning rate: 1.220703125e-08


epoch: 45/300 Step: 131 loss : 0.2345 : 100%|██████████| 130/130 [01:40<00:00,  1.29it/s]
epoch: 45/300 Step: 34 loss : 0.0812 : 100%|██████████| 33/33 [00:07<00:00,  4.54it/s]


Epoch [45/300], Validation Loss: 0.0812
Current learning rate: 1.220703125e-08


epoch: 46/300 Step: 131 loss : 0.2334 : 100%|██████████| 130/130 [01:39<00:00,  1.30it/s]
epoch: 46/300 Step: 34 loss : 0.0843 : 100%|██████████| 33/33 [00:06<00:00,  5.14it/s]


Epoch [46/300], Validation Loss: 0.0843
Current learning rate: 1.220703125e-08


epoch: 47/300 Step: 131 loss : 0.2338 : 100%|██████████| 130/130 [01:39<00:00,  1.30it/s]
epoch: 47/300 Step: 34 loss : 0.0832 : 100%|██████████| 33/33 [00:06<00:00,  5.14it/s]


Epoch [47/300], Validation Loss: 0.0832
Current learning rate: 1.220703125e-08


epoch: 48/300 Step: 131 loss : 0.2310 : 100%|██████████| 130/130 [01:40<00:00,  1.30it/s]
epoch: 48/300 Step: 34 loss : 0.0837 : 100%|██████████| 33/33 [00:06<00:00,  5.16it/s]


Epoch [48/300], Validation Loss: 0.0837
Current learning rate: 1.220703125e-08


epoch: 49/300 Step: 131 loss : 0.2340 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 49/300 Step: 34 loss : 0.0841 : 100%|██████████| 33/33 [00:07<00:00,  4.67it/s]


Epoch [49/300], Validation Loss: 0.0841
Current learning rate: 1.220703125e-08


epoch: 50/300 Step: 131 loss : 0.2325 : 100%|██████████| 130/130 [01:38<00:00,  1.32it/s]
epoch: 50/300 Step: 34 loss : 0.0841 : 100%|██████████| 33/33 [00:06<00:00,  5.33it/s]


Epoch [50/300], Validation Loss: 0.0841
Current learning rate: 1.220703125e-08


epoch: 51/300 Step: 131 loss : 0.2341 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 51/300 Step: 34 loss : 0.0822 : 100%|██████████| 33/33 [00:06<00:00,  5.16it/s]


Epoch [51/300], Validation Loss: 0.0822
Current learning rate: 1.220703125e-08


epoch: 52/300 Step: 131 loss : 0.2330 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 52/300 Step: 34 loss : 0.0847 : 100%|██████████| 33/33 [00:06<00:00,  5.11it/s]


Epoch [52/300], Validation Loss: 0.0847
Current learning rate: 1.220703125e-08


epoch: 53/300 Step: 131 loss : 0.2360 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 53/300 Step: 34 loss : 0.0855 : 100%|██████████| 33/33 [00:06<00:00,  4.85it/s]


Epoch [53/300], Validation Loss: 0.0855
Current learning rate: 1.220703125e-08


epoch: 54/300 Step: 131 loss : 0.2305 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 54/300 Step: 34 loss : 0.0830 : 100%|██████████| 33/33 [00:06<00:00,  5.10it/s]


Epoch [54/300], Validation Loss: 0.0830
Current learning rate: 1.220703125e-08


epoch: 55/300 Step: 131 loss : 0.2319 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 55/300 Step: 34 loss : 0.0839 : 100%|██████████| 33/33 [00:06<00:00,  5.17it/s]


Epoch [55/300], Validation Loss: 0.0839
Current learning rate: 1.220703125e-08


epoch: 56/300 Step: 131 loss : 0.2348 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 56/300 Step: 34 loss : 0.0854 : 100%|██████████| 33/33 [00:06<00:00,  5.09it/s]


Epoch [56/300], Validation Loss: 0.0854
Current learning rate: 1.220703125e-08


epoch: 57/300 Step: 131 loss : 0.2336 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 57/300 Step: 34 loss : 0.0807 : 100%|██████████| 33/33 [00:06<00:00,  5.21it/s]


Epoch [57/300], Validation Loss: 0.0807
Current learning rate: 1.220703125e-08
Best model saved with Validation Loss: 2.663696263451129


epoch: 58/300 Step: 131 loss : 0.2383 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 58/300 Step: 34 loss : 0.0842 : 100%|██████████| 33/33 [00:06<00:00,  4.99it/s]


Epoch [58/300], Validation Loss: 0.0842
Current learning rate: 1.220703125e-08


epoch: 59/300 Step: 131 loss : 0.2320 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 59/300 Step: 34 loss : 0.0844 : 100%|██████████| 33/33 [00:06<00:00,  5.06it/s]


Epoch [59/300], Validation Loss: 0.0844
Current learning rate: 1.220703125e-08


epoch: 60/300 Step: 131 loss : 0.2326 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 60/300 Step: 34 loss : 0.0834 : 100%|██████████| 33/33 [00:06<00:00,  4.76it/s]


Epoch [60/300], Validation Loss: 0.0834
Current learning rate: 1.220703125e-08


epoch: 61/300 Step: 131 loss : 0.2332 : 100%|██████████| 130/130 [01:39<00:00,  1.31it/s]
epoch: 61/300 Step: 34 loss : 0.0829 : 100%|██████████| 33/33 [00:06<00:00,  5.07it/s]


Epoch [61/300], Validation Loss: 0.0829
Current learning rate: 1.220703125e-08


epoch: 62/300 Step: 43 loss : 0.2297 :  32%|███▏      | 42/130 [00:32<01:08,  1.28it/s]