In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import transforms

import json

from PIL import Image

from transformers import CLIPProcessor, CLIPModel

import UNet_utils
import ddpm_utils

# Kiểm tra và thiết lập thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [23]:
CLIP_FEATURES = 512
T = 400
IMG_CH = 3
IMG_SIZE = 64
model = UNet_utils.UNet(
    T, IMG_CH, IMG_SIZE, down_chs=(256, 256, 512), t_embed_dim=8, c_embed_dim=CLIP_FEATURES
)

In [24]:
IMG_SIZE = 64 # Due to stride and pooling, must be divisible by 2 multiple times
BATCH_SIZE = 32
INPUT_SIZE = (IMG_CH, IMG_SIZE, IMG_SIZE)

pre_transforms = [
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),  # Scales data into [0,1]
    transforms.Lambda(lambda t: (t * 2) - 1)  # Scale between [-1, 1]
]
pre_transforms = transforms.Compose(pre_transforms)
random_transforms = [
    transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
]
random_transforms = transforms.Compose(random_transforms)

In [25]:
DIR = 'naruto_images/'
file_path = 'captions.json'
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model_hf = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# Hàm khởi tạo embeding.
def clip_encode(text, processor, device): 
    inputs = processor(
        text=text, 
        return_tensors="pt", 
        padding="max_length",
        truncation=True, 
        max_length=20, 
        return_attention_mask=True
    ).to(device)

    with torch.no_grad():
        outputs = clip_model_hf.text_model(**inputs)

    return outputs.last_hidden_state * inputs["attention_mask"].unsqueeze(-1) , outputs.pooler_output  # (B, 512)
class MyDataset(Dataset):
    def __init__(self, file_path, preprocessed_clip=True):
        self.imgs = []
        self.labelsT = []  # (B,T,D)
        self.labels = []   # (B,D)
        self.preprocessed_clip = preprocessed_clip

        with open(file_path, 'r', encoding='utf-8') as file:
            reader = json.load(file)
            for item in reader:
                img = Image.open(DIR + item['image']).convert('RGB')
                self.imgs.append(pre_transforms(img).to(device))

                if preprocessed_clip:
                    embeding_T, embeding = clip_encode(item['caption'], processor, device)
                    self.labelsT.append(embeding_T)  # (B,T,D)
                    self.labels.append(embeding)    # (B,D)

    def __getitem__(self, idx):
        img = random_transforms(self.imgs[idx])
        #if self.preprocessed_clip:
        labelT = self.labelsT[idx]
        label = self.labels[idx]
        #else:
        # batch_img = img[None, :, :, :]
        # encoded_imgs = clip_model.encode_image(clip_preprocess(batch_img))
        # label = encoded_imgs.to(device).float()[0]  # (D)
        # labelT = encoded_imgs.to(device).float()[0]  # (D)
        
        return img, label, labelT

    def __len__(self):
        return len(self.imgs)

In [None]:
import threading
import time
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

stop_training = threading.Event()  # Cờ để dừng training
CLIP_FEATURES = 512
T = 400
IMG_CH = 3
IMG_SIZE = 64
INPUT_SIZE = (IMG_CH, IMG_SIZE, IMG_SIZE)

model = UNet_utils.UNet(
    T, IMG_CH, IMG_SIZE, down_chs=(256, 256, 512), t_embed_dim=8, c_embed_dim=CLIP_FEATURES
)
model = model.to(device)

train_data = MyDataset(file_path)
dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

lrate = 1e-5
optimizer = optim.Adam(model.parameters(), lr=lrate)

try:
    checkpoint = torch.load('checkpoint.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    current_epoch = checkpoint['epoch'] + 1
    print(f"Checkpoint loaded! Resuming from epoch {current_epoch}")
except FileNotFoundError:
    print("No checkpoint found. Starting from scratch.")
    current_epoch = 0

B_start = 0.0001
B_end = 0.02
B = torch.linspace(B_start, B_end, T).to(device)
ddpm = ddpm_utils.DDPM(B, device)

def get_context_mask(c, c_drop_prob=0.1):
    """Tạo mặt nạ ngữ cảnh"""
    c_mask = torch.bernoulli(torch.ones_like(c).float() - c_drop_prob).to(device)
    return c_mask

def timer_function(hours=11):
    """Hàm chạy bộ đếm thời gian"""
    time.sleep(hours * 3600)  # Đợi đúng số giờ
    stop_training.set()  # Đánh dấu dừng training
    print("Time's up! Training will stop soon.")

timer_thread = threading.Thread(target=timer_function, args=(0.01,))
timer_thread.start()

epochs = 5  # Số epoch muốn train tiếp
model.train()

for epoch in range(current_epoch, current_epoch + epochs):
    if stop_training.is_set():  # Kiểm tra nếu hết thời gian thì thoát
        print("Stopping training due to timeout.")
        break
    
    for step, batch in enumerate(dataloader):
        if stop_training.is_set():  # Kiểm tra trong vòng lặp batch
            print("Training stopped at step:", step)
            break

        model.zero_grad()
        t = torch.randint(0, T, (BATCH_SIZE,), device=device).float()
        x, c, c1 = batch
        c_mask = get_context_mask(c, c_drop_prob=0.1)
        loss = ddpm.get_loss(model, x, t, c, c1, c_mask)
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch} completed, loss: {loss.item()}")

# ====== Lưu checkpoint ======
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}
torch.save(checkpoint, 'checkpoint.pth')
print("Checkpoint saved successfully!")

# ====== Đợi luồng đếm thời gian kết thúc ======
timer_thread.join()


  checkpoint = torch.load('checkpoint.pth', map_location=device)


Checkpoint loaded! Resuming from epoch 1
Epoch 1 completed, loss: 0.14790010452270508
Time's up! Training will stop soon.
Training stopped at step: 8
Epoch 2 completed, loss: 0.14373034238815308
Stopping training due to timeout.
Checkpoint saved successfully!
