In [1]:
import os,json
import torch
import torch.nn as nn
from changechat.model.multimodal_encoder.clip_encoder import CLIPVisionTower
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.optim as optim
from torch.nn import CrossEntropyLoss

  from .autonotebook import tqdm as notebook_tqdm


[2024-08-07 22:17:39,155] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


## 定义模型

In [2]:
import math
class Classifier(nn.Module):
    def __init__(
        self, mm_hidden_size, hidden_size=None, img_feature_w=18, img_feature_h=18
    ):
        super(Classifier, self).__init__()
        self.d_model = mm_hidden_size

        encoder_self_layer_classifier = nn.TransformerEncoderLayer(
            2 * self.d_model, nhead=2, dim_feedforward=int(2 * self.d_model)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_self_layer_classifier, num_layers=1
        )
        self.w_embedding = nn.Embedding(img_feature_w, int(self.d_model / 2))
        self.h_embedding = nn.Embedding(img_feature_h, int(self.d_model / 2))
        self.classifier_head = nn.Linear(2 * self.d_model, 2)

        # cls_token
        scale = self.d_model**-0.5
        self.cls_changeflag = nn.Parameter(scale * torch.randn(1, 2 * self.d_model))

    def position_embedding_2D_func(self, img_feat_A, img_feat_B):
        device = img_feat_A.device
        batch = img_feat_B.shape[0]
        Len_feat = img_feat_B.shape[1]
        h = int(math.sqrt(Len_feat))
        w = h
        pos_w = torch.arange(w, device=device).to(device)
        pos_h = torch.arange(h, device=device).to(device)
        embed_w = self.w_embedding(pos_w)
        embed_h = self.h_embedding(pos_h)
        position_embedding = torch.cat(
            [
                embed_w.unsqueeze(0).repeat(h, 1, 1),
                embed_h.unsqueeze(1).repeat(1, w, 1),
            ],
            dim=-1,
        )
        position_embedding = position_embedding.unsqueeze(0).repeat(
            batch, 1, 1, 1
        )  # (batch, h, w, d_model)
        position_embedding = position_embedding.view(batch, -1, self.d_model)
        img_feat_A = img_feat_A + position_embedding  # NLD
        img_feat_B = img_feat_B + position_embedding  # NLD

        return img_feat_A, img_feat_B

    def forward(self, img_feat):
        img_feat_A = img_feat[:, 0, ...]  # (N,L,768)
        img_feat_B = img_feat[:, 1, ...]  # (N,L,768)

        # 2D image position_embedding
        img_feat_A, img_feat_B = self.position_embedding_2D_func(
            img_feat_A, img_feat_B
        )  # (N, L, D)

        img_feat = torch.cat([img_feat_A, img_feat_B], dim=-1)  # (N, L, 2D)
        img_feat_with_cls = torch.cat(
            [
                self.cls_changeflag.unsqueeze(0).expand(
                    img_feat.shape[0], *self.cls_changeflag.shape
                ),
                img_feat,
            ],
            dim=1,
        )

        img_feat_with_cls = self.transformer_encoder(
            img_feat_with_cls.permute(1, 0, 2)
        ).permute(
            1, 0, 2
        )  # (N, L, 2D)
        change_pred = self.classifier_head(img_feat_with_cls[:, 0, :])
        return change_pred

In [3]:
@dataclass
class ChangeClassifierConfig:
    mm_vision_tower = "/root/autodl-tmp/GeoChat/hf-models/clip-vit-large-patch14-336"
    mm_vision_select_layer = -2

class ChangeClassifier(nn.Module):
    def __init__(self, config):
        super(ChangeClassifier, self).__init__()
        self.vision_tower = CLIPVisionTower(config.mm_vision_tower, config)
        # freeeze vision_tower
        self.vision_tower.requires_grad_(False)

        if config.mm_vision_tower.endswith("clip-vit-large-patch14-336"):
            mm_hidden_size = 1024
        else:
            raise NotImplementedError
        self.classifier = Classifier(mm_hidden_size = mm_hidden_size)

    def forward(self, images):
        # 提取两幅图像的特征
        assert images.ndim == 5
        concat_images = torch.cat([image for image in images], dim=0)
        image_features = self.vision_tower(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(
            image_features, split_sizes, dim=0
        )  # b tuples of [2, N, L]
        image_features = torch.stack(image_features, dim=0)  # [b, 2, N, L]
        change_pred = self.classifier(image_features)
        return change_pred

## 定义数据集

In [9]:
@dataclass
class DataConfig:
    data_path = "/root/autodl-tmp/LEVIR-MCI-dataset/ChangeChat_classify.json"
    image_folder = "/root/autodl-tmp/LEVIR-MCI-dataset/images"
    image_processor = None

class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        config
    ):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(config.data_path, "r"))
        self.list_data_dict = list_data_dict
        self.processor = config.image_processor
        self.image_folder = config.image_folder

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i):
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        image_file = self.list_data_dict[i]["image"]
        image_folder = self.image_folder
        processor = self.processor

        if isinstance(image_file, str):
            image_file_list = [image_file]
        elif isinstance(image_file, list):
            image_file_list = image_file

        imageList = []
        for _image_file in image_file_list:
            image = Image.open(
                (os.path.join(image_folder, _image_file)).strip()
            ).convert("RGB")
            image = processor.preprocess(
                image,
                do_resize=True,
                crop_size={"height": 252, "width": 252},
                size={"shortest_edge": 252},
                return_tensors="pt",
            )["pixel_values"][0]
            imageList.append(image)

        data_dict = dict()
        # 将多幅图像拼成一个Tensor
        data_dict["images"] = torch.stack(imageList, dim=0)  # (2, c, h, w)

        # 如果有变化标签
        if "changeflag" in self.list_data_dict[i]:
            data_dict["change_labels"] = torch.tensor(
                self.list_data_dict[i]["changeflag2"]
            )
        return data_dict

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = ChangeClassifierConfig()
model = ChangeClassifier(config)
model = model.to(device)
# 创建数据集和数据加载器
dataset_config = DataConfig()
dataset_config.image_processor = model.vision_tower.image_processor
train_dataset = LazySupervisedDataset(dataset_config)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 验证集
dataset_config.data_path = "/root/autodl-tmp/GeoChat/load/Test_CC_gt.json"
test_dataset = LazySupervisedDataset(dataset_config)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

params = [v for v in model.parameters() if v.requires_grad]
num_params = sum([v.numel() for v in model.parameters() if v.requires_grad])/1000000
optimizer = optim.Adam(params, lr=1e-4)
criterion = CrossEntropyLoss()

In [6]:
import wandb
wandb.init(project="change_classifier", config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhanlinwu[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
import tqdm
def calculate_accuracy(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm.tqdm(data_loader):
            images, labels = batch['images'].to(device), batch['change_labels'].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy

In [11]:
val_accuracy = calculate_accuracy(model, test_loader, device)
print(f"Validation Accuracy: {val_accuracy:.4f}")

  0%|          | 0/16 [00:00<?, ?it/s]


KeyError: 'changeflag2'

In [8]:
import datetime
num_epochs = 400
print_step = 10
global_step = 0
output_dir = "/root/autodl-tmp/GeoChat/experiments/classifier"
now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = output_dir + now
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, batch in enumerate(train_loader):
        images, labels = batch['images'].to(device), batch['change_labels'].to(device)
        
        # 零梯度
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(images)
        
        # 计算损失
        loss = criterion(outputs, labels)
        
        # 反向传播
        loss.backward()
        
        # 更新参数
        optimizer.step()
        global_step += 1
        
        running_loss += loss.item()

        if (i+1)%print_step == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")
            wandb.log({
                "epoch": epoch + 1,
                "global_step": global_step,
                "loss": loss.item()
            })
    
    # 打印每个epoch的平均损失
    val_accuracy = calculate_accuracy(model, test_loader, device)
    print(f"Epoch [{epoch+1}/{num_epochs}] Validation Accuracy: {val_accuracy:.4f}")
    torch.save(model.classifier.state_dict(), os.path.join(output_dir,"classifier_weight.pth"))
    wandb.log({
        "epoch": epoch + 1,
        "val_accuracy": val_accuracy,
    })

wandb.finish()

Epoch [1/400], Step [10/213], Loss: 0.6896
Epoch [1/400], Step [20/213], Loss: 0.8022
Epoch [1/400], Step [30/213], Loss: 0.5419
Epoch [1/400], Step [40/213], Loss: 0.6253
Epoch [1/400], Step [50/213], Loss: 0.4691
Epoch [1/400], Step [60/213], Loss: 0.3945
Epoch [1/400], Step [70/213], Loss: 0.3566
Epoch [1/400], Step [80/213], Loss: 0.3365
Epoch [1/400], Step [90/213], Loss: 0.2182
Epoch [1/400], Step [100/213], Loss: 0.1896
Epoch [1/400], Step [110/213], Loss: 0.3334
Epoch [1/400], Step [120/213], Loss: 0.3010
Epoch [1/400], Step [130/213], Loss: 0.2211
Epoch [1/400], Step [140/213], Loss: 0.1551
Epoch [1/400], Step [150/213], Loss: 0.5036
Epoch [1/400], Step [160/213], Loss: 0.1395
Epoch [1/400], Step [170/213], Loss: 0.3048
Epoch [1/400], Step [180/213], Loss: 0.2300
Epoch [1/400], Step [190/213], Loss: 0.2855
Epoch [1/400], Step [200/213], Loss: 0.2224
Epoch [1/400], Step [210/213], Loss: 0.2734


100%|██████████| 16/16 [01:13<00:00,  4.59s/it]


Epoch [1/400] Validation Accuracy: 0.9160
Epoch [2/400], Step [10/213], Loss: 0.1945
Epoch [2/400], Step [20/213], Loss: 0.0991
Epoch [2/400], Step [30/213], Loss: 0.1677
Epoch [2/400], Step [40/213], Loss: 0.2006
Epoch [2/400], Step [50/213], Loss: 0.2458
Epoch [2/400], Step [60/213], Loss: 0.2479
Epoch [2/400], Step [70/213], Loss: 0.1665
Epoch [2/400], Step [80/213], Loss: 0.2897
Epoch [2/400], Step [90/213], Loss: 0.2025
Epoch [2/400], Step [100/213], Loss: 0.3231
Epoch [2/400], Step [110/213], Loss: 0.2127
Epoch [2/400], Step [120/213], Loss: 0.1730
Epoch [2/400], Step [130/213], Loss: 0.2556
Epoch [2/400], Step [140/213], Loss: 0.4395
Epoch [2/400], Step [150/213], Loss: 0.3377
Epoch [2/400], Step [160/213], Loss: 0.1409
Epoch [2/400], Step [170/213], Loss: 0.1532
Epoch [2/400], Step [180/213], Loss: 0.2074
Epoch [2/400], Step [190/213], Loss: 0.0988
Epoch [2/400], Step [200/213], Loss: 0.1294
Epoch [2/400], Step [210/213], Loss: 0.0997


100%|██████████| 16/16 [01:13<00:00,  4.58s/it]


Epoch [2/400] Validation Accuracy: 0.8979
Epoch [3/400], Step [10/213], Loss: 0.1188
Epoch [3/400], Step [20/213], Loss: 0.1742
Epoch [3/400], Step [30/213], Loss: 0.1341
Epoch [3/400], Step [40/213], Loss: 0.2538
Epoch [3/400], Step [50/213], Loss: 0.1759
Epoch [3/400], Step [60/213], Loss: 0.3366
Epoch [3/400], Step [70/213], Loss: 0.2801
Epoch [3/400], Step [80/213], Loss: 0.1236
Epoch [3/400], Step [90/213], Loss: 0.2401
Epoch [3/400], Step [100/213], Loss: 0.1453
Epoch [3/400], Step [110/213], Loss: 0.0829
Epoch [3/400], Step [120/213], Loss: 0.2364
Epoch [3/400], Step [130/213], Loss: 0.1724
Epoch [3/400], Step [140/213], Loss: 0.0853
Epoch [3/400], Step [150/213], Loss: 0.1240
Epoch [3/400], Step [160/213], Loss: 0.1341
Epoch [3/400], Step [170/213], Loss: 0.2044
Epoch [3/400], Step [180/213], Loss: 0.2015
Epoch [3/400], Step [190/213], Loss: 0.0971
Epoch [3/400], Step [200/213], Loss: 0.3246
Epoch [3/400], Step [210/213], Loss: 0.3627


100%|██████████| 16/16 [01:13<00:00,  4.59s/it]


Epoch [3/400] Validation Accuracy: 0.9295
Epoch [4/400], Step [10/213], Loss: 0.0981
Epoch [4/400], Step [20/213], Loss: 0.1626
Epoch [4/400], Step [30/213], Loss: 0.1158
Epoch [4/400], Step [40/213], Loss: 0.1139
Epoch [4/400], Step [50/213], Loss: 0.2616
Epoch [4/400], Step [60/213], Loss: 0.0821
Epoch [4/400], Step [70/213], Loss: 0.1285
Epoch [4/400], Step [80/213], Loss: 0.5139
Epoch [4/400], Step [90/213], Loss: 0.4061
Epoch [4/400], Step [100/213], Loss: 0.1931
Epoch [4/400], Step [110/213], Loss: 0.3040
Epoch [4/400], Step [120/213], Loss: 0.1070
Epoch [4/400], Step [130/213], Loss: 0.0936
Epoch [4/400], Step [140/213], Loss: 0.1552
Epoch [4/400], Step [150/213], Loss: 0.2354
Epoch [4/400], Step [160/213], Loss: 0.0799
Epoch [4/400], Step [170/213], Loss: 0.2658
Epoch [4/400], Step [180/213], Loss: 0.0835
Epoch [4/400], Step [190/213], Loss: 0.0682
Epoch [4/400], Step [200/213], Loss: 0.1521
Epoch [4/400], Step [210/213], Loss: 0.2811


100%|██████████| 16/16 [01:13<00:00,  4.60s/it]


Epoch [4/400] Validation Accuracy: 0.9212
Epoch [5/400], Step [10/213], Loss: 0.0387
Epoch [5/400], Step [20/213], Loss: 0.1332
Epoch [5/400], Step [30/213], Loss: 0.0719
Epoch [5/400], Step [40/213], Loss: 0.2544
Epoch [5/400], Step [50/213], Loss: 0.4358
Epoch [5/400], Step [60/213], Loss: 0.2765
Epoch [5/400], Step [70/213], Loss: 0.2152
Epoch [5/400], Step [80/213], Loss: 0.0259
Epoch [5/400], Step [90/213], Loss: 0.1612
Epoch [5/400], Step [100/213], Loss: 0.1747
Epoch [5/400], Step [110/213], Loss: 0.0622
Epoch [5/400], Step [120/213], Loss: 0.1153
Epoch [5/400], Step [130/213], Loss: 0.0777
Epoch [5/400], Step [140/213], Loss: 0.1385
Epoch [5/400], Step [150/213], Loss: 0.0991
Epoch [5/400], Step [160/213], Loss: 0.2675
Epoch [5/400], Step [170/213], Loss: 0.3063
Epoch [5/400], Step [180/213], Loss: 0.1733
Epoch [5/400], Step [190/213], Loss: 0.0807
Epoch [5/400], Step [200/213], Loss: 0.0937
Epoch [5/400], Step [210/213], Loss: 0.3153


100%|██████████| 16/16 [01:13<00:00,  4.60s/it]


Epoch [5/400] Validation Accuracy: 0.9316
Epoch [6/400], Step [10/213], Loss: 0.0411
Epoch [6/400], Step [20/213], Loss: 0.1221
Epoch [6/400], Step [30/213], Loss: 0.1515
Epoch [6/400], Step [40/213], Loss: 0.1563
Epoch [6/400], Step [50/213], Loss: 0.2683
Epoch [6/400], Step [60/213], Loss: 0.0438
Epoch [6/400], Step [70/213], Loss: 0.1784
Epoch [6/400], Step [80/213], Loss: 0.0948
Epoch [6/400], Step [90/213], Loss: 0.0447
Epoch [6/400], Step [100/213], Loss: 0.2085
Epoch [6/400], Step [110/213], Loss: 0.2652
Epoch [6/400], Step [120/213], Loss: 0.1346
Epoch [6/400], Step [130/213], Loss: 0.1303
Epoch [6/400], Step [140/213], Loss: 0.1529
Epoch [6/400], Step [150/213], Loss: 0.0737
Epoch [6/400], Step [160/213], Loss: 0.1404
Epoch [6/400], Step [170/213], Loss: 0.2292
Epoch [6/400], Step [180/213], Loss: 0.0988
Epoch [6/400], Step [190/213], Loss: 0.0769
Epoch [6/400], Step [200/213], Loss: 0.0730
Epoch [6/400], Step [210/213], Loss: 0.1090


100%|██████████| 16/16 [01:13<00:00,  4.61s/it]


Epoch [6/400] Validation Accuracy: 0.9269
Epoch [7/400], Step [10/213], Loss: 0.2317
Epoch [7/400], Step [20/213], Loss: 0.0479
Epoch [7/400], Step [30/213], Loss: 0.0464
Epoch [7/400], Step [40/213], Loss: 0.1298
Epoch [7/400], Step [50/213], Loss: 0.2457
Epoch [7/400], Step [60/213], Loss: 0.1299
Epoch [7/400], Step [70/213], Loss: 0.2764
Epoch [7/400], Step [80/213], Loss: 0.2219
Epoch [7/400], Step [90/213], Loss: 0.1464
Epoch [7/400], Step [100/213], Loss: 0.1600
Epoch [7/400], Step [110/213], Loss: 0.2197
Epoch [7/400], Step [120/213], Loss: 0.1265
Epoch [7/400], Step [130/213], Loss: 0.0911
Epoch [7/400], Step [140/213], Loss: 0.2379
Epoch [7/400], Step [150/213], Loss: 0.0466
Epoch [7/400], Step [160/213], Loss: 0.2267
Epoch [7/400], Step [170/213], Loss: 0.0542
Epoch [7/400], Step [180/213], Loss: 0.1571
Epoch [7/400], Step [190/213], Loss: 0.0500
Epoch [7/400], Step [200/213], Loss: 0.1859
Epoch [7/400], Step [210/213], Loss: 0.1258


100%|██████████| 16/16 [01:13<00:00,  4.60s/it]


Epoch [7/400] Validation Accuracy: 0.9253
Epoch [8/400], Step [10/213], Loss: 0.1153
Epoch [8/400], Step [20/213], Loss: 0.1427
Epoch [8/400], Step [30/213], Loss: 0.0304
Epoch [8/400], Step [40/213], Loss: 0.1252
Epoch [8/400], Step [50/213], Loss: 0.2752
Epoch [8/400], Step [60/213], Loss: 0.1000
Epoch [8/400], Step [70/213], Loss: 0.1970
Epoch [8/400], Step [80/213], Loss: 0.1167
Epoch [8/400], Step [90/213], Loss: 0.1043
Epoch [8/400], Step [100/213], Loss: 0.2243
Epoch [8/400], Step [110/213], Loss: 0.1486
Epoch [8/400], Step [120/213], Loss: 0.0124
Epoch [8/400], Step [130/213], Loss: 0.0732
Epoch [8/400], Step [140/213], Loss: 0.1236
Epoch [8/400], Step [150/213], Loss: 0.0726
Epoch [8/400], Step [160/213], Loss: 0.0514
Epoch [8/400], Step [170/213], Loss: 0.0646
Epoch [8/400], Step [180/213], Loss: 0.0771
Epoch [8/400], Step [190/213], Loss: 0.0464
Epoch [8/400], Step [200/213], Loss: 0.0819
Epoch [8/400], Step [210/213], Loss: 0.1785


100%|██████████| 16/16 [01:13<00:00,  4.58s/it]


Epoch [8/400] Validation Accuracy: 0.9207
Epoch [9/400], Step [10/213], Loss: 0.2004
Epoch [9/400], Step [20/213], Loss: 0.0356
Epoch [9/400], Step [30/213], Loss: 0.0899
Epoch [9/400], Step [40/213], Loss: 0.1588
Epoch [9/400], Step [50/213], Loss: 0.1443
Epoch [9/400], Step [60/213], Loss: 0.1993
Epoch [9/400], Step [70/213], Loss: 0.0545
Epoch [9/400], Step [80/213], Loss: 0.1408
Epoch [9/400], Step [90/213], Loss: 0.0267
Epoch [9/400], Step [100/213], Loss: 0.0607
Epoch [9/400], Step [110/213], Loss: 0.0817
Epoch [9/400], Step [120/213], Loss: 0.0896
Epoch [9/400], Step [130/213], Loss: 0.3660
Epoch [9/400], Step [140/213], Loss: 0.1885
Epoch [9/400], Step [150/213], Loss: 0.0478
Epoch [9/400], Step [160/213], Loss: 0.1151
Epoch [9/400], Step [170/213], Loss: 0.0207
Epoch [9/400], Step [180/213], Loss: 0.1318
Epoch [9/400], Step [190/213], Loss: 0.1763
Epoch [9/400], Step [200/213], Loss: 0.0464
Epoch [9/400], Step [210/213], Loss: 0.0708


100%|██████████| 16/16 [01:13<00:00,  4.60s/it]


Epoch [9/400] Validation Accuracy: 0.9248
Epoch [10/400], Step [10/213], Loss: 0.0927
Epoch [10/400], Step [20/213], Loss: 0.1244
Epoch [10/400], Step [30/213], Loss: 0.1368
Epoch [10/400], Step [40/213], Loss: 0.0518
Epoch [10/400], Step [50/213], Loss: 0.0284
Epoch [10/400], Step [60/213], Loss: 0.1651
Epoch [10/400], Step [70/213], Loss: 0.1515
Epoch [10/400], Step [80/213], Loss: 0.2211
Epoch [10/400], Step [90/213], Loss: 0.1378
Epoch [10/400], Step [100/213], Loss: 0.0793
Epoch [10/400], Step [110/213], Loss: 0.0461
Epoch [10/400], Step [120/213], Loss: 0.2128
Epoch [10/400], Step [130/213], Loss: 0.1913
Epoch [10/400], Step [140/213], Loss: 0.0745
Epoch [10/400], Step [150/213], Loss: 0.1775
Epoch [10/400], Step [160/213], Loss: 0.0511
Epoch [10/400], Step [170/213], Loss: 0.0498
Epoch [10/400], Step [180/213], Loss: 0.0645
Epoch [10/400], Step [190/213], Loss: 0.0922
Epoch [10/400], Step [200/213], Loss: 0.0830
Epoch [10/400], Step [210/213], Loss: 0.0286


100%|██████████| 16/16 [01:13<00:00,  4.59s/it]


Epoch [10/400] Validation Accuracy: 0.9290
Epoch [11/400], Step [10/213], Loss: 0.0667
Epoch [11/400], Step [20/213], Loss: 0.2496
Epoch [11/400], Step [30/213], Loss: 0.1694
Epoch [11/400], Step [40/213], Loss: 0.0427
Epoch [11/400], Step [50/213], Loss: 0.0769
Epoch [11/400], Step [60/213], Loss: 0.0855
Epoch [11/400], Step [70/213], Loss: 0.1030
Epoch [11/400], Step [80/213], Loss: 0.0605
Epoch [11/400], Step [90/213], Loss: 0.0908
Epoch [11/400], Step [100/213], Loss: 0.0445
Epoch [11/400], Step [110/213], Loss: 0.0295
Epoch [11/400], Step [120/213], Loss: 0.0786
Epoch [11/400], Step [130/213], Loss: 0.2304
Epoch [11/400], Step [140/213], Loss: 0.1054
Epoch [11/400], Step [150/213], Loss: 0.1164
Epoch [11/400], Step [160/213], Loss: 0.0823
Epoch [11/400], Step [170/213], Loss: 0.0525
Epoch [11/400], Step [180/213], Loss: 0.1091
Epoch [11/400], Step [190/213], Loss: 0.1448
Epoch [11/400], Step [200/213], Loss: 0.1527
Epoch [11/400], Step [210/213], Loss: 0.2366


100%|██████████| 16/16 [01:13<00:00,  4.62s/it]


Epoch [11/400] Validation Accuracy: 0.9222
Epoch [12/400], Step [10/213], Loss: 0.0356
Epoch [12/400], Step [20/213], Loss: 0.0549
Epoch [12/400], Step [30/213], Loss: 0.1711
Epoch [12/400], Step [40/213], Loss: 0.1379
Epoch [12/400], Step [50/213], Loss: 0.0665
Epoch [12/400], Step [60/213], Loss: 0.0148
Epoch [12/400], Step [70/213], Loss: 0.4212
Epoch [12/400], Step [80/213], Loss: 0.0248
Epoch [12/400], Step [90/213], Loss: 0.1605
Epoch [12/400], Step [100/213], Loss: 0.1190
Epoch [12/400], Step [110/213], Loss: 0.2985
Epoch [12/400], Step [120/213], Loss: 0.0999
Epoch [12/400], Step [130/213], Loss: 0.0315
Epoch [12/400], Step [140/213], Loss: 0.0450
Epoch [12/400], Step [150/213], Loss: 0.1080
Epoch [12/400], Step [160/213], Loss: 0.1532
Epoch [12/400], Step [170/213], Loss: 0.0138
Epoch [12/400], Step [180/213], Loss: 0.0728
Epoch [12/400], Step [190/213], Loss: 0.1033
Epoch [12/400], Step [200/213], Loss: 0.1038
Epoch [12/400], Step [210/213], Loss: 0.2679


100%|██████████| 16/16 [01:13<00:00,  4.60s/it]


Epoch [12/400] Validation Accuracy: 0.9222
Epoch [13/400], Step [10/213], Loss: 0.0763
Epoch [13/400], Step [20/213], Loss: 0.0508
Epoch [13/400], Step [30/213], Loss: 0.0303
Epoch [13/400], Step [40/213], Loss: 0.1611
Epoch [13/400], Step [50/213], Loss: 0.0288
Epoch [13/400], Step [60/213], Loss: 0.0708
Epoch [13/400], Step [70/213], Loss: 0.1549
Epoch [13/400], Step [80/213], Loss: 0.0788
Epoch [13/400], Step [90/213], Loss: 0.0924
Epoch [13/400], Step [100/213], Loss: 0.1281
Epoch [13/400], Step [110/213], Loss: 0.0728
Epoch [13/400], Step [120/213], Loss: 0.0090
Epoch [13/400], Step [130/213], Loss: 0.0870
Epoch [13/400], Step [140/213], Loss: 0.0656
Epoch [13/400], Step [150/213], Loss: 0.1519
Epoch [13/400], Step [160/213], Loss: 0.0901
Epoch [13/400], Step [170/213], Loss: 0.0109
Epoch [13/400], Step [180/213], Loss: 0.2576
Epoch [13/400], Step [190/213], Loss: 0.2005
Epoch [13/400], Step [200/213], Loss: 0.0940
Epoch [13/400], Step [210/213], Loss: 0.1014


100%|██████████| 16/16 [01:13<00:00,  4.58s/it]


Epoch [13/400] Validation Accuracy: 0.9285
Epoch [14/400], Step [10/213], Loss: 0.2033
Epoch [14/400], Step [20/213], Loss: 0.0779
Epoch [14/400], Step [30/213], Loss: 0.0204
Epoch [14/400], Step [40/213], Loss: 0.0716
Epoch [14/400], Step [50/213], Loss: 0.0598
Epoch [14/400], Step [60/213], Loss: 0.0636
Epoch [14/400], Step [70/213], Loss: 0.2213
Epoch [14/400], Step [80/213], Loss: 0.1425
Epoch [14/400], Step [90/213], Loss: 0.0434
Epoch [14/400], Step [100/213], Loss: 0.1575
Epoch [14/400], Step [110/213], Loss: 0.0302
Epoch [14/400], Step [120/213], Loss: 0.0222
Epoch [14/400], Step [130/213], Loss: 0.1738
Epoch [14/400], Step [140/213], Loss: 0.1325
Epoch [14/400], Step [150/213], Loss: 0.0749
Epoch [14/400], Step [160/213], Loss: 0.0173
Epoch [14/400], Step [170/213], Loss: 0.0226
Epoch [14/400], Step [180/213], Loss: 0.0562
Epoch [14/400], Step [190/213], Loss: 0.1882
Epoch [14/400], Step [200/213], Loss: 0.0492
Epoch [14/400], Step [210/213], Loss: 0.1467


100%|██████████| 16/16 [01:13<00:00,  4.59s/it]


Epoch [14/400] Validation Accuracy: 0.9253
Epoch [15/400], Step [10/213], Loss: 0.0195
Epoch [15/400], Step [20/213], Loss: 0.1527
Epoch [15/400], Step [30/213], Loss: 0.0106
Epoch [15/400], Step [40/213], Loss: 0.0640
Epoch [15/400], Step [50/213], Loss: 0.0125
Epoch [15/400], Step [60/213], Loss: 0.0725
Epoch [15/400], Step [70/213], Loss: 0.0916
Epoch [15/400], Step [80/213], Loss: 0.0179
Epoch [15/400], Step [90/213], Loss: 0.1380
Epoch [15/400], Step [100/213], Loss: 0.0582
Epoch [15/400], Step [110/213], Loss: 0.1044
Epoch [15/400], Step [120/213], Loss: 0.0523
Epoch [15/400], Step [130/213], Loss: 0.1372
Epoch [15/400], Step [140/213], Loss: 0.0831
Epoch [15/400], Step [150/213], Loss: 0.0520
Epoch [15/400], Step [160/213], Loss: 0.0103
Epoch [15/400], Step [170/213], Loss: 0.1906
Epoch [15/400], Step [180/213], Loss: 0.0009
Epoch [15/400], Step [190/213], Loss: 0.0348
Epoch [15/400], Step [200/213], Loss: 0.1149
Epoch [15/400], Step [210/213], Loss: 0.4052


100%|██████████| 16/16 [01:13<00:00,  4.61s/it]


Epoch [15/400] Validation Accuracy: 0.9098
Epoch [16/400], Step [10/213], Loss: 0.2592
Epoch [16/400], Step [20/213], Loss: 0.0342
Epoch [16/400], Step [30/213], Loss: 0.0852
Epoch [16/400], Step [40/213], Loss: 0.0125
Epoch [16/400], Step [50/213], Loss: 0.0924
Epoch [16/400], Step [60/213], Loss: 0.0665
Epoch [16/400], Step [70/213], Loss: 0.0665
Epoch [16/400], Step [80/213], Loss: 0.1510
Epoch [16/400], Step [90/213], Loss: 0.0614
Epoch [16/400], Step [100/213], Loss: 0.1650
Epoch [16/400], Step [110/213], Loss: 0.1040
Epoch [16/400], Step [120/213], Loss: 0.1541
Epoch [16/400], Step [130/213], Loss: 0.0115
Epoch [16/400], Step [140/213], Loss: 0.1149
Epoch [16/400], Step [150/213], Loss: 0.0787
Epoch [16/400], Step [160/213], Loss: 0.0603
Epoch [16/400], Step [170/213], Loss: 0.0635
Epoch [16/400], Step [180/213], Loss: 0.0334
Epoch [16/400], Step [190/213], Loss: 0.1313
Epoch [16/400], Step [200/213], Loss: 0.0662
Epoch [16/400], Step [210/213], Loss: 0.0546


100%|██████████| 16/16 [01:13<00:00,  4.60s/it]


Epoch [16/400] Validation Accuracy: 0.9243
Epoch [17/400], Step [10/213], Loss: 0.0809
Epoch [17/400], Step [20/213], Loss: 0.1118
Epoch [17/400], Step [30/213], Loss: 0.0637
Epoch [17/400], Step [40/213], Loss: 0.0353
Epoch [17/400], Step [50/213], Loss: 0.0614
Epoch [17/400], Step [60/213], Loss: 0.0321
Epoch [17/400], Step [70/213], Loss: 0.1235
Epoch [17/400], Step [80/213], Loss: 0.2810
Epoch [17/400], Step [90/213], Loss: 0.0582
Epoch [17/400], Step [100/213], Loss: 0.0258
Epoch [17/400], Step [110/213], Loss: 0.1050
Epoch [17/400], Step [120/213], Loss: 0.0214
Epoch [17/400], Step [130/213], Loss: 0.1971
Epoch [17/400], Step [140/213], Loss: 0.0329
Epoch [17/400], Step [150/213], Loss: 0.0877
Epoch [17/400], Step [160/213], Loss: 0.1038
Epoch [17/400], Step [170/213], Loss: 0.0125
Epoch [17/400], Step [180/213], Loss: 0.1071
Epoch [17/400], Step [190/213], Loss: 0.0677
Epoch [17/400], Step [200/213], Loss: 0.0672
Epoch [17/400], Step [210/213], Loss: 0.0352


100%|██████████| 16/16 [01:13<00:00,  4.60s/it]


Epoch [17/400] Validation Accuracy: 0.9243
Epoch [18/400], Step [10/213], Loss: 0.0475
Epoch [18/400], Step [20/213], Loss: 0.0095
Epoch [18/400], Step [30/213], Loss: 0.1552
Epoch [18/400], Step [40/213], Loss: 0.1455
Epoch [18/400], Step [50/213], Loss: 0.0353
Epoch [18/400], Step [60/213], Loss: 0.0270
Epoch [18/400], Step [70/213], Loss: 0.0652
Epoch [18/400], Step [80/213], Loss: 0.1042
Epoch [18/400], Step [90/213], Loss: 0.0671
Epoch [18/400], Step [100/213], Loss: 0.0010
Epoch [18/400], Step [110/213], Loss: 0.1244
Epoch [18/400], Step [120/213], Loss: 0.0283
Epoch [18/400], Step [130/213], Loss: 0.1755
Epoch [18/400], Step [140/213], Loss: 0.0156
Epoch [18/400], Step [150/213], Loss: 0.0607
Epoch [18/400], Step [160/213], Loss: 0.1087
Epoch [18/400], Step [170/213], Loss: 0.1682
Epoch [18/400], Step [180/213], Loss: 0.0951
Epoch [18/400], Step [190/213], Loss: 0.1952
Epoch [18/400], Step [200/213], Loss: 0.1702
Epoch [18/400], Step [210/213], Loss: 0.0882


100%|██████████| 16/16 [01:13<00:00,  4.58s/it]


Epoch [18/400] Validation Accuracy: 0.9228
Epoch [19/400], Step [10/213], Loss: 0.1810
Epoch [19/400], Step [20/213], Loss: 0.0540
Epoch [19/400], Step [30/213], Loss: 0.0107
Epoch [19/400], Step [40/213], Loss: 0.0455
Epoch [19/400], Step [50/213], Loss: 0.1461
Epoch [19/400], Step [60/213], Loss: 0.1280
Epoch [19/400], Step [70/213], Loss: 0.0813
Epoch [19/400], Step [80/213], Loss: 0.2084
Epoch [19/400], Step [90/213], Loss: 0.0589
Epoch [19/400], Step [100/213], Loss: 0.0477
Epoch [19/400], Step [110/213], Loss: 0.0437
Epoch [19/400], Step [120/213], Loss: 0.0921
Epoch [19/400], Step [130/213], Loss: 0.1261
Epoch [19/400], Step [140/213], Loss: 0.0376


KeyboardInterrupt: 