In [1]:
import model_loader
import pipeline
from PIL import Image
from transformers import CLIPTokenizer
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import os
from torchvision import transforms
from torchvision.transforms import functional as F

import torch
from torch.optim import  SGD
from tqdm import tqdm
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter


from clip import CLIP
from ddpm import DDPMSampler
from pipeline import generate, get_time_embedding

from pipeline import preprocess_image
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# GPU 설정
device_ids = [1, 2, 3, 4, 5]  
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using device: {device} on GPUs {device_ids}")

device = f"cuda:{device_ids[0]}"  # 첫 번째 GPU를 기본 device로 설정
print(f"Using device: {device} on GPUs {device_ids}")

# 토크나이저 및 모델 로드
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
model_file = "../data/v1-5-pruned-emaonly.ckpt"
models = model_loader.preload_models_from_standard_weights(model_file, device)

# Diffusion 모델 병렬화 설정
model = models["diffusion"]
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model, device_ids=device_ids)  # 모델 병렬화
model = model.to(device)

Using device: cuda:1 on GPUs [1, 2, 3, 4, 5]


In [3]:
import os
from PIL import Image
from torch.utils.data import Dataset

class PairedImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.pairs = self._make_pairs()

    def _make_pairs(self):
        pairs = []
        valid_extensions = {".jpg", ".jpeg", ".png"}  # 유효한 이미지 확장자

        # 각 폴더를 순회하면서 blur와 sharp 이미지 쌍을 생성합니다.
        for folder_name in os.listdir(self.root_dir):
            folder_path = os.path.join(self.root_dir, folder_name)
            blur_dir = os.path.join(folder_path, "blur")
            sharp_dir = os.path.join(folder_path, "sharp")
            
            if not (os.path.isdir(blur_dir) and os.path.isdir(sharp_dir)):
                continue
            
            # blur와 sharp 디렉토리에서 동일한 파일 이름을 가진 이미지 쌍을 찾습니다.
            for image_name in os.listdir(blur_dir):
                # 유효한 이미지 파일만 선택
                if not any(image_name.lower().endswith(ext) for ext in valid_extensions):
                    continue
                
                blur_image_path = os.path.join(blur_dir, image_name)
                sharp_image_path = os.path.join(sharp_dir, image_name)
                
                # 두 파일 모두 존재할 때만 추가
                if os.path.isfile(blur_image_path) and os.path.isfile(sharp_image_path):
                    pairs.append((blur_image_path, sharp_image_path))
        
        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        blur_image_path, sharp_image_path = self.pairs[idx]
        blur_image = Image.open(blur_image_path).convert("RGB")
        sharp_image = Image.open(sharp_image_path).convert("RGB")
        
        if self.transform:
            blur_image = self.transform(blur_image)
            sharp_image = self.transform(sharp_image)
        
        return blur_image, sharp_image  # (blurred input, sharp target)

# 이미지 중앙 부분 crop
transform = transforms.Compose([
    transforms.CenterCrop((512, 512)), 
    transforms.ToTensor()
])

# 데이터셋 및 데이터 로더
train_data = PairedImageDataset("../images/GOPRO/train", transform=transform)
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)

# Test 데이터셋과 DataLoader 생성
# test_dataset = PairedImageDataset("../images/GOPRO/test", transform=transform)
# test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [4]:
def rescale(x, old_range=(-1, 1), new_range=(0, 1), clamp=False):
    old_min, old_max = old_range
    new_min, new_max = new_range 

    # rescale x from old_range to new_range
    x = (x - old_min) * (new_max - new_min) / (old_max - old_min) + new_min
    if clamp:
        x = torch.clamp(x, min=new_min, max=new_max)  # 값이 new_min과 new_max 사이에 있도록 클램핑
    return x

In [5]:
WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

class LDMFineTuner:
    def __init__(self, models, tokenizer, device, log_dir="runs/ldm_finetune", save_path="./checkpoints"):
        self.models = models
        self.tokenizer = tokenizer
        self.device = device
        self.clip = models['clip'].to(device)
        self.diffusion = models['diffusion'].to(device)
        self.encoder = models['encoder'].to(device)
        self.decoder = models['decoder'].to(device)
        
        # 옵티마이저 설정 (Diffusion 모델만 학습)
        self.optimizer = SGD(self.diffusion.parameters(), lr=1e-5)
        
        self.sampler = DDPMSampler(generator=torch.Generator(device=device))
        self.sampler.set_inference_timesteps(num_inference_steps=1000) 

        # TensorBoard writer 설정
        self.writer = SummaryWriter(log_dir=log_dir)
        
        # 모델과 옵티마이저 저장 경로 설정
        self.save_path = save_path
        os.makedirs(self.save_path, exist_ok=True)  # 디렉토리 생성


    def prepare_batch(self, batch):
        blur_images, sharp_images = batch
        blur_images = blur_images.to(self.device)
        sharp_images = sharp_images.to(self.device)

        # 블러 이미지와 선명한 이미지에 대한 다운샘플링된 노이즈 생성
        batch_size, _, height, width = blur_images.size()
        noise_height, noise_width = height // 8, width // 8  # 다운샘플링 크기에 맞게 설정
        noise_for_blur = torch.randn(batch_size, 4, noise_height, noise_width, device=self.device)
        noise_for_sharp = torch.randn(batch_size, 4, noise_height, noise_width, device=self.device)

        # 블러 이미지를 latent space로 인코딩
        with torch.no_grad():
            blur_latents = self.encoder(blur_images, noise_for_blur)
            sharp_latents = self.encoder(sharp_images, noise_for_sharp)  # Target latents for deblurring
        
        # CLIP을 사용하여 텍스트 임베딩 생성 (기본 프롬프트 사용)
        # 여기서는 임의의 "Deblur image" 프롬프트를 사용하여 context 생성
        tokens = self.tokenizer(["Deblur image"] * blur_images.size(0), 
                                padding="max_length", max_length=77, 
                                return_tensors="pt").input_ids.to(self.device)
        context = self.clip(tokens)
        
        return blur_latents, sharp_latents, context
    
    def train_step(self, blur_latents, sharp_latents, context):
        batch_size = blur_latents.shape[0]
        
        # 랜덤 타임스텝 선택
        t = torch.randint(0, self.sampler.num_train_timesteps, (batch_size,), device=self.device).long()
        
        # 타깃 latents에 노이즈 추가
        noise = torch.randn_like(sharp_latents)
        # 블러 latent에 노이즈 추가
        noisy_latents = self.sampler.add_noise(blur_latents, t)  
        
        # 시간 임베딩
        time_embedding = get_time_embedding(t).to(self.device)
        
        # Diffusion 모델로 노이즈 예측
        predicted_noise = self.diffusion(noisy_latents, context, time_embedding)
        
        # 손실 계산 (예측된 노이즈와 실제 노이즈 간의 차이)
        loss = torch.nn.functional.mse_loss(predicted_noise, noise)
        
        return loss
    
    def save_model(self, epoch):
        # 모델과 옵티마이저 상태를 지정된 경로에 저장
        model_path = os.path.join(self.save_path, f"model_epoch_{epoch}.pth")
        optimizer_path = os.path.join(self.save_path, f"optimizer_epoch_{epoch}.pth")
        
        torch.save(self.diffusion.state_dict(), model_path)
        torch.save(self.optimizer.state_dict(), optimizer_path)
        
        print(f"Model and optimizer saved at epoch {epoch} in {self.save_path}")


    def train(self, dataloader, num_epochs):
        global_step = 0  # 전체 학습 과정에서의 스텝 카운터를 추가

        for epoch in range(num_epochs):
            total_loss = 0
            for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                blur_images, sharp_images = batch
                blur_latents, sharp_latents, context = self.prepare_batch(batch)
                
                self.optimizer.zero_grad()
                loss = self.train_step(blur_latents, sharp_latents, context)
                loss.backward()
                self.optimizer.step()
                
                # 미니 배치 손실을 TensorBoard에 기록합니다.
                # self.writer.add_scalar("Loss/Train_Batch", loss.item(), global_step)
                
                total_loss += loss.item()
                # global_step += 1 
            
            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
            self.writer.add_scalar("Loss/Train", avg_loss, epoch)

            # 모델과 옵티마이저 저장
            self.save_model(epoch + 1)

            # 샘플 블러 이미지를 선택하여 디블러링 성능을 시각화
            # sample_blur_image = blur_images[0]  # 첫 번째 이미지 선택
            # self.log_generated_images(epoch + 1, sample_blur_image)

    # def log_generated_images(self, epoch):
    #     # 샘플 blur 이미지 생성 후 deblur 결과 기록
    #     with torch.no_grad():
    #         # 간단한 샘플 블러 이미지를 생성하고 복원
    #         sample_blur_latent = torch.randn(1, 4, LATENTS_HEIGHT, LATENTS_WIDTH, device=self.device)
    #         tokens = self.tokenizer(["Deblur image"], padding="max_length", max_length=77, return_tensors="pt").input_ids.to(self.device)
    #         context = self.clip(tokens)

    #         timesteps = reversed(self.sampler.timesteps)
    #         for t in timesteps:
    #             time_embedding = get_time_embedding(t).to(self.device)
    #             predicted_noise = self.diffusion(sample_blur_latent, context, time_embedding)
    #             sample_blur_latent = self.sampler.step(t, sample_blur_latent, predicted_noise)
            
    #         # 디코더로 최종 이미지 생성
    #         generated_image = self.decoder(sample_blur_latent)
    #         generated_image = rescale(generated_image, (-1, 1), (0, 1)).cpu()  # TensorBoard에 출력할 수 있게 rescale

    #         # TensorBoard에 이미지 기록
    #         self.writer.add_image("Generated/Deblurred_Image", generated_image.squeeze(), epoch)

    def generate_sample(self, prompt, input_image, **kwargs):
        return generate(prompt, input_image, models=self.models, tokenizer=self.tokenizer, device=self.device, **kwargs)    
    def log_generated_images(self, epoch, blur_image):
        # 샘플 블러 이미지를 인자로 받아서 디블러링 결과 생성 및 기록
        with torch.no_grad():
            # "Deblur image" 프롬프트를 사용하여 디블러링 결과 생성
            prompt = "Deblur image"
            uncond_prompt = ""  # Also known as negative prompt
            do_cfg = True
            cfg_scale = 8  # min: 1, max: 14 prompt에 집중하는 정도
            strength = 1.0

            ## SAMPLER

            sampler = "ddpm"
            num_inference_steps = 50
            seed = 42


            # # Check if blur_image is a Tensor and preprocess it
            # if isinstance(blur_image, torch.Tensor):
            #     # Ensure the tensor is in CHW format before conversion
            #     blur_image = blur_image.squeeze()  # Remove unnecessary dimensions
            #     blur_image = F.to_pil_image(blur_image)  # Convert to PIL image
            # generated_image = self.generate_sample(prompt=prompt, input_image=blur_image)
            generated_image = pipeline.generate(
                prompt=prompt,
                uncond_prompt = uncond_prompt,
                input_image=blur_image,
                strength=strength,
                do_cfg=do_cfg,
                cfg_scale=cfg_scale,
                sampler_name=sampler,
                n_inference_steps=num_inference_steps,
                seed=seed,
                models=model.module,
                device=device,
                idle_device="cpu",
                tokenizer=tokenizer,
            )         

            # TensorBoard에 기록하기 위한 이미지 스케일 조정
            blur_image = rescale(torch.tensor(blur_image), (0, 255), (0, 1)).unsqueeze(0)
            generated_image = rescale(torch.tensor(generated_image), (0, 255), (0, 1)).unsqueeze(0)
            
            # TensorBoard에 디블러링된 이미지 기록
            self.writer.add_image("Input/Blurred_Image", blur_image, epoch)
            self.writer.add_image("Generated/Deblurred_Image", generated_image, epoch)




In [6]:
# image_path = "../images/blur_image.png"
# input_image = Image.open(image_path)

# fine_tuner = LDMFineTuner(models, tokenizer, device)
# fine_tuner.log_generated_images(1, input_image)


In [7]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# 학습
fine_tuner = LDMFineTuner(models, tokenizer, device)
fine_tuner.train(train_loader, num_epochs=100)

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


Epoch 1/100:   0%|          | 0/1423 [00:00<?, ?it/s]

Epoch 1/100: 100%|██████████| 1423/1423 [22:30<00:00,  1.05it/s]


Epoch 1/100, Average Loss: 1.8771
Model and optimizer saved at epoch 1 in ./checkpoints


Epoch 2/100: 100%|██████████| 1423/1423 [22:43<00:00,  1.04it/s]


Epoch 2/100, Average Loss: 1.8495
Model and optimizer saved at epoch 2 in ./checkpoints


Epoch 3/100: 100%|██████████| 1423/1423 [23:47<00:00,  1.00s/it]


Epoch 3/100, Average Loss: 1.8207
Model and optimizer saved at epoch 3 in ./checkpoints


Epoch 4/100: 100%|██████████| 1423/1423 [22:39<00:00,  1.05it/s]


Epoch 4/100, Average Loss: 1.7935
Model and optimizer saved at epoch 4 in ./checkpoints


Epoch 5/100: 100%|██████████| 1423/1423 [22:32<00:00,  1.05it/s]


Epoch 5/100, Average Loss: 1.7674
Model and optimizer saved at epoch 5 in ./checkpoints


Epoch 6/100: 100%|██████████| 1423/1423 [22:34<00:00,  1.05it/s]


Epoch 6/100, Average Loss: 1.7406
Model and optimizer saved at epoch 6 in ./checkpoints


Epoch 7/100: 100%|██████████| 1423/1423 [22:38<00:00,  1.05it/s]


Epoch 7/100, Average Loss: 1.7176
Model and optimizer saved at epoch 7 in ./checkpoints


Epoch 8/100: 100%|██████████| 1423/1423 [22:35<00:00,  1.05it/s]


Epoch 8/100, Average Loss: 1.6854
Model and optimizer saved at epoch 8 in ./checkpoints


Epoch 9/100: 100%|██████████| 1423/1423 [22:54<00:00,  1.03it/s]


Epoch 9/100, Average Loss: 1.6530
Model and optimizer saved at epoch 9 in ./checkpoints


Epoch 10/100: 100%|██████████| 1423/1423 [23:03<00:00,  1.03it/s]


Epoch 10/100, Average Loss: 1.6102
Model and optimizer saved at epoch 10 in ./checkpoints


Epoch 11/100: 100%|██████████| 1423/1423 [23:03<00:00,  1.03it/s]


Epoch 11/100, Average Loss: 1.5605
Model and optimizer saved at epoch 11 in ./checkpoints


Epoch 12/100: 100%|██████████| 1423/1423 [23:01<00:00,  1.03it/s]


Epoch 12/100, Average Loss: 1.4976
Model and optimizer saved at epoch 12 in ./checkpoints


Epoch 13/100: 100%|██████████| 1423/1423 [23:06<00:00,  1.03it/s]


Epoch 13/100, Average Loss: 1.4260
Model and optimizer saved at epoch 13 in ./checkpoints


Epoch 14/100: 100%|██████████| 1423/1423 [23:13<00:00,  1.02it/s]


Epoch 14/100, Average Loss: 1.3506
Model and optimizer saved at epoch 14 in ./checkpoints


Epoch 15/100: 100%|██████████| 1423/1423 [23:05<00:00,  1.03it/s]


Epoch 15/100, Average Loss: 1.2909
Model and optimizer saved at epoch 15 in ./checkpoints


Epoch 16/100: 100%|██████████| 1423/1423 [22:59<00:00,  1.03it/s]


Epoch 16/100, Average Loss: 1.2425
Model and optimizer saved at epoch 16 in ./checkpoints


Epoch 17/100: 100%|██████████| 1423/1423 [24:10<00:00,  1.02s/it]


Epoch 17/100, Average Loss: 1.2056
Model and optimizer saved at epoch 17 in ./checkpoints


Epoch 18/100: 100%|██████████| 1423/1423 [23:01<00:00,  1.03it/s]


Epoch 18/100, Average Loss: 1.1750
Model and optimizer saved at epoch 18 in ./checkpoints


Epoch 19/100: 100%|██████████| 1423/1423 [23:02<00:00,  1.03it/s]


Epoch 19/100, Average Loss: 1.1570
Model and optimizer saved at epoch 19 in ./checkpoints


Epoch 20/100: 100%|██████████| 1423/1423 [22:43<00:00,  1.04it/s]


Epoch 20/100, Average Loss: 1.1351
Model and optimizer saved at epoch 20 in ./checkpoints


Epoch 21/100: 100%|██████████| 1423/1423 [22:29<00:00,  1.05it/s]


Epoch 21/100, Average Loss: 1.1197
Model and optimizer saved at epoch 21 in ./checkpoints


Epoch 22/100: 100%|██████████| 1423/1423 [22:28<00:00,  1.05it/s]


Epoch 22/100, Average Loss: 1.1109
Model and optimizer saved at epoch 22 in ./checkpoints


Epoch 23/100: 100%|██████████| 1423/1423 [22:40<00:00,  1.05it/s]


Epoch 23/100, Average Loss: 1.1005
Model and optimizer saved at epoch 23 in ./checkpoints


Epoch 24/100: 100%|██████████| 1423/1423 [23:03<00:00,  1.03it/s]


Epoch 24/100, Average Loss: 1.0929
Model and optimizer saved at epoch 24 in ./checkpoints


Epoch 25/100: 100%|██████████| 1423/1423 [23:09<00:00,  1.02it/s]


Epoch 25/100, Average Loss: 1.0856
Model and optimizer saved at epoch 25 in ./checkpoints


Epoch 26/100: 100%|██████████| 1423/1423 [23:04<00:00,  1.03it/s]


Epoch 26/100, Average Loss: 1.0817
Model and optimizer saved at epoch 26 in ./checkpoints


Epoch 27/100: 100%|██████████| 1423/1423 [23:03<00:00,  1.03it/s]


Epoch 27/100, Average Loss: 1.0770
Model and optimizer saved at epoch 27 in ./checkpoints


Epoch 28/100: 100%|██████████| 1423/1423 [22:57<00:00,  1.03it/s]


Epoch 28/100, Average Loss: 1.0710
Model and optimizer saved at epoch 28 in ./checkpoints


Epoch 29/100: 100%|██████████| 1423/1423 [23:02<00:00,  1.03it/s]


Epoch 29/100, Average Loss: 1.0655
Model and optimizer saved at epoch 29 in ./checkpoints


Epoch 30/100: 100%|██████████| 1423/1423 [22:54<00:00,  1.04it/s]


Epoch 30/100, Average Loss: 1.0630
Model and optimizer saved at epoch 30 in ./checkpoints


Epoch 31/100: 100%|██████████| 1423/1423 [24:08<00:00,  1.02s/it]


Epoch 31/100, Average Loss: 1.0594
Model and optimizer saved at epoch 31 in ./checkpoints


Epoch 32/100: 100%|██████████| 1423/1423 [23:01<00:00,  1.03it/s]


Epoch 32/100, Average Loss: 1.0562
Model and optimizer saved at epoch 32 in ./checkpoints


Epoch 33/100: 100%|██████████| 1423/1423 [22:54<00:00,  1.03it/s]


Epoch 33/100, Average Loss: 1.0536
Model and optimizer saved at epoch 33 in ./checkpoints


Epoch 34/100: 100%|██████████| 1423/1423 [22:57<00:00,  1.03it/s]


Epoch 34/100, Average Loss: 1.0522
Model and optimizer saved at epoch 34 in ./checkpoints


Epoch 35/100: 100%|██████████| 1423/1423 [23:01<00:00,  1.03it/s]


Epoch 35/100, Average Loss: 1.0497
Model and optimizer saved at epoch 35 in ./checkpoints


Epoch 36/100: 100%|██████████| 1423/1423 [22:57<00:00,  1.03it/s]


Epoch 36/100, Average Loss: 1.0490
Model and optimizer saved at epoch 36 in ./checkpoints


Epoch 37/100: 100%|██████████| 1423/1423 [23:00<00:00,  1.03it/s]


Epoch 37/100, Average Loss: 1.0452
Model and optimizer saved at epoch 37 in ./checkpoints


Epoch 38/100: 100%|██████████| 1423/1423 [22:58<00:00,  1.03it/s]


Epoch 38/100, Average Loss: 1.0438
Model and optimizer saved at epoch 38 in ./checkpoints


Epoch 39/100: 100%|██████████| 1423/1423 [22:59<00:00,  1.03it/s]


Epoch 39/100, Average Loss: 1.0417
Model and optimizer saved at epoch 39 in ./checkpoints


Epoch 40/100: 100%|██████████| 1423/1423 [23:21<00:00,  1.02it/s]


Epoch 40/100, Average Loss: 1.0412
Model and optimizer saved at epoch 40 in ./checkpoints


Epoch 41/100: 100%|██████████| 1423/1423 [23:07<00:00,  1.03it/s]


Epoch 41/100, Average Loss: 1.0397
Model and optimizer saved at epoch 41 in ./checkpoints


Epoch 42/100: 100%|██████████| 1423/1423 [23:10<00:00,  1.02it/s]


Epoch 42/100, Average Loss: 1.0382
Model and optimizer saved at epoch 42 in ./checkpoints


Epoch 43/100: 100%|██████████| 1423/1423 [23:06<00:00,  1.03it/s]


Epoch 43/100, Average Loss: 1.0374
Model and optimizer saved at epoch 43 in ./checkpoints


Epoch 44/100: 100%|██████████| 1423/1423 [23:02<00:00,  1.03it/s]


Epoch 44/100, Average Loss: 1.0363
Model and optimizer saved at epoch 44 in ./checkpoints


Epoch 45/100: 100%|██████████| 1423/1423 [22:51<00:00,  1.04it/s]


Epoch 45/100, Average Loss: 1.0356
Model and optimizer saved at epoch 45 in ./checkpoints


Epoch 46/100: 100%|██████████| 1423/1423 [23:35<00:00,  1.01it/s]


Epoch 46/100, Average Loss: 1.0347
Model and optimizer saved at epoch 46 in ./checkpoints


Epoch 47/100: 100%|██████████| 1423/1423 [22:34<00:00,  1.05it/s]


Epoch 47/100, Average Loss: 1.0331
Model and optimizer saved at epoch 47 in ./checkpoints


Epoch 48/100: 100%|██████████| 1423/1423 [22:43<00:00,  1.04it/s]


Epoch 48/100, Average Loss: 1.0331
Model and optimizer saved at epoch 48 in ./checkpoints


Epoch 49/100: 100%|██████████| 1423/1423 [22:54<00:00,  1.04it/s]


Epoch 49/100, Average Loss: 1.0325
Model and optimizer saved at epoch 49 in ./checkpoints


Epoch 50/100: 100%|██████████| 1423/1423 [23:00<00:00,  1.03it/s]


Epoch 50/100, Average Loss: 1.0316
Model and optimizer saved at epoch 50 in ./checkpoints


Epoch 51/100: 100%|██████████| 1423/1423 [23:04<00:00,  1.03it/s]


Epoch 51/100, Average Loss: 1.0304
Model and optimizer saved at epoch 51 in ./checkpoints


Epoch 52/100: 100%|██████████| 1423/1423 [23:03<00:00,  1.03it/s]


Epoch 52/100, Average Loss: 1.0309
Model and optimizer saved at epoch 52 in ./checkpoints


Epoch 53/100: 100%|██████████| 1423/1423 [23:09<00:00,  1.02it/s]


Epoch 53/100, Average Loss: 1.0299
Model and optimizer saved at epoch 53 in ./checkpoints


Epoch 54/100: 100%|██████████| 1423/1423 [23:05<00:00,  1.03it/s]


Epoch 54/100, Average Loss: 1.0290
Model and optimizer saved at epoch 54 in ./checkpoints


Epoch 55/100: 100%|██████████| 1423/1423 [23:11<00:00,  1.02it/s]


Epoch 55/100, Average Loss: 1.0286
Model and optimizer saved at epoch 55 in ./checkpoints


Epoch 56/100: 100%|██████████| 1423/1423 [23:04<00:00,  1.03it/s]


Epoch 56/100, Average Loss: 1.0274
Model and optimizer saved at epoch 56 in ./checkpoints


Epoch 57/100:  92%|█████████▏| 1305/1423 [21:18<01:56,  1.01it/s]