<a href="https://colab.research.google.com/github/zoujiulong/Multimodal/blob/main/Score_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q omegaconf torchvision tqdm matplotlib

In [None]:
import math
import torch
import torch.nn as nn

# class TimeEmedding(nn.Module):
#     def __init__(self,in_channel):
#         super().__init__()
#         self.in_channel=in_channel
#         self.proj1=nn.Linear(in_channel//4,in_channel)
#         self.act1=nn.ELU()
#         self.proj2=nn.Linear(in_channel,in_channel)
#     def forward(self,t):
#         print('t',t.shape)
#         # t_emb=torch.empty((*t.shape[:-1],self.in_channel//4))
#         print('t_emb',t_emb.shape)
#         emb=math.log(10000)/(self.in_channel//8)
#         emb=t[:,None]*torch.exp(torch.arange(self.in_channel//8)*-emb)
#         print('emb',emb.shape)
#         t_emb[:,:,:,::2]=emb.sin()
#         t_emb[:,:,:,1::2]=emb.cos()
#         t_emb=self.proj1(t_emb)
#         t_emb=self.act1(t_emb)
#         t_emb=self.proj2(t_emb)
#         return t_emb
class Swish(nn.Module):
    r"""
    ### Swish activation function

    Element-wise $x \cdot \sigma(x)$ (a.k.a. SiLU). Parameter-free.
    """

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x
class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$

    Sinusoidal features of one scalar per sample (in this notebook, the noise
    level sigma) followed by a two-layer MLP; returns (batch, n_channels).
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding. It must be
          at least 16 so there are at least two sinusoid frequencies (the
          frequency spacing divides by `half_dim - 1`).
        """
        super().__init__()
        if n_channels < 16:
            raise ValueError(f"n_channels must be >= 16, got {n_channels}")
        self.n_channels = n_channels
        # Number of sinusoid frequencies. sin and cos are concatenated, so the
        # MLP input has 2 * half_dim features (== n_channels // 4 whenever
        # n_channels is a multiple of 8, so existing checkpoints still load).
        self.half_dim = n_channels // 8
        # First linear layer
        self.lin1 = nn.Linear(2 * self.half_dim, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed `t`.

        `t` must hold one scalar per sample: shape (batch,), (batch, 1, ..., 1)
        (e.g. a sigma tensor viewed for broadcasting against images), or a 0-d
        scalar (treated as a batch of one). Returns (batch, n_channels).
        """
        # Flatten to (batch,) so the broadcast below stays 2-D, and cast to the
        # MLP dtype so integer / float64 inputs do not clash with the weights.
        t = t.reshape(-1).to(self.lin1.weight.dtype)
        # Sinusoidal position embeddings, same as those from the transformer:
        #
        # \begin{align}
        # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
        # PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
        # \end{align}
        #
        # where $d$ is `half_dim`
        scale = math.log(10_000) / (self.half_dim - 1)
        freqs = torch.exp(torch.arange(self.half_dim, device=t.device) * -scale)
        emb = t[:, None] * freqs[None, :]
        # Concatenate on the feature axis -> (batch, 2 * half_dim)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        return self.lin2(emb)

class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Input and output are (bs, ch, h, w); the result is added to the input
    (residual connection). ``ch`` must be divisible by ``head_num``.

    Args:
        head_num: number of attention heads.
        ch: channel count of the feature map.
        p: dropout probability applied to attention weights and the output.
    """
    def __init__(self,head_num,ch,p=0.1):
        super().__init__()
        if ch%head_num!=0:
            raise ValueError(f"ch ({ch}) must be divisible by head_num ({head_num})")
        self.head_num=head_num
        self.head_dim=ch//head_num
        self.qkv=nn.Linear(ch,3*ch)
        # Normalise over the key axis (last dim of q @ k^T). nn.Softmax() with
        # no dim resolves to dim=1 for 4-D input, i.e. across heads.
        self.softmax=nn.Softmax(dim=-1)
        self.drop1=nn.Dropout(p)
        self.linear=nn.Linear(ch,ch)
        self.drop2=nn.Dropout(p)

    def forward(self,x):
        bs,ch,h,w=x.shape
        # (bs, ch, h, w) -> (bs, h*w, ch): one token per spatial position
        seq=x.reshape(bs,ch,h*w).permute(0,2,1)
        # (bs, h*w, 3*ch) -> (bs, heads, h*w, 3*head_dim); each head's q, k, v
        # are the three contiguous thirds of its slice.
        qkv=self.qkv(seq).view(bs,h*w,self.head_num,3*self.head_dim).permute(0,2,1,3)
        q,k,v=torch.chunk(qkv,3,dim=-1)
        scores=(q@k.transpose(-2,-1))/(self.head_dim)**0.5
        attn=self.drop1(self.softmax(scores))
        # (bs, heads, h*w, head_dim) -> (bs, h*w, ch)
        o=(attn@v).permute(0,2,1,3).reshape(bs,h*w,ch)
        o=self.drop2(self.linear(o))
        return o.permute(0,2,1).reshape(bs,ch,h,w)+x

class ResidualBlock(nn.Module):
    """Pre-activation residual block conditioned on a time/noise embedding.

    Path: GroupNorm -> Swish -> 3x3 conv, add the projected embedding, then
    GroupNorm -> Swish -> dropout -> 3x3 conv; the input is added back through
    a 1x1 conv when the channel count changes (identity otherwise).

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels of the output feature map.
        time_channel: feature size of the embedding passed as ``t``.
        n_groups: GroupNorm groups; must divide both channel counts.
        p: dropout probability.
    """
    def __init__(self,in_channel,out_channel,time_channel,n_groups=32,p=0.1):
        super().__init__()
        self.in_ch=in_channel
        self.ins1=nn.GroupNorm(n_groups, in_channel)
        self.act1=Swish()
        self.conv1=nn.Conv2d(in_channel,out_channel,kernel_size=(3,3),padding=1)
        self.ins2=nn.GroupNorm(n_groups, out_channel)
        self.act2=Swish()
        self.conv2=nn.Conv2d(out_channel,out_channel,kernel_size=(3,3),padding=1)
        if in_channel!=out_channel:
            self.residual=nn.Conv2d(in_channel,out_channel,kernel_size=(1,1))
        else:
            self.residual=nn.Identity()
        self.time_emb = nn.Linear(time_channel, out_channel)
        self.time_act = Swish()
        self.dropout = nn.Dropout(p)

    def forward(self,img,t):
        """Apply the block; ``t`` may be None to skip conditioning."""
        out=self.conv1(self.act1(self.ins1(img)))
        if t is not None:
            # The Linear yields (bs, out_channel); add two trailing axes so it
            # broadcasts over the spatial dimensions.
            out=out+self.time_emb(self.time_act(t))[:, :, None, None]
        out=self.conv2(self.dropout(self.act2(self.ins2(out))))
        return out+self.residual(img)

class DownBlock(nn.Module):
    """Encoder stage: a ResidualBlock, optionally followed by self-attention.

    Args:
        in_channel, out_channel: channel counts in and out of the residual block.
        time_channel: size of the conditioning embedding.
        head_num: attention heads (used only when ``has_attention`` is true).
        has_attention: whether to append an Attention layer.
    """
    def __init__(self,in_channel,out_channel,time_channel,head_num,has_attention=False):
        super().__init__()
        self.residual=ResidualBlock(in_channel,out_channel,time_channel)
        self.attention=(
            Attention(head_num,out_channel) if has_attention else nn.Identity()
        )

    def forward(self,img,t):
        return self.attention(self.residual(img,t))

class MiddleBlock(nn.Module):
    """Bottleneck: ResidualBlock -> Attention -> ResidualBlock.

    Args:
        in_channel, out_channel: channels entering / leaving the first block.
        time_channel: size of the conditioning embedding.
        head_num: attention heads.
    """
    def __init__(self,in_channel,out_channel,time_channel,head_num):
        super().__init__()
        self.residual1=ResidualBlock(in_channel,out_channel,time_channel)
        self.attention=Attention(head_num,out_channel)
        self.residual2=ResidualBlock(out_channel,out_channel,time_channel)

    def forward(self,img,t):
        out=self.residual1(img,t)
        out=self.attention(out)
        return self.residual2(out,t)

class UpBlock(nn.Module):
    """Decoder stage: a ResidualBlock, optionally followed by self-attention.

    The residual block takes ``in_channel + out_channel`` inputs because the
    caller concatenates a skip connection onto ``img`` before calling forward.

    Args:
        in_channel: channels of the decoder feature before concatenation.
        out_channel: channels produced (and channels of the concatenated skip).
        time_channel: size of the conditioning embedding.
        head_num: attention heads (used only when ``has_attention`` is true).
        has_attention: whether to append an Attention layer.
    """
    def __init__(self,in_channel,out_channel,time_channel,head_num,has_attention=False):
        super().__init__()
        fused_channels=in_channel+out_channel
        self.residual=ResidualBlock(fused_channels,out_channel,time_channel)
        self.attention=(
            Attention(head_num,out_channel) if has_attention else nn.Identity()
        )

    def forward(self,img,t):
        return self.attention(self.residual(img,t))

class DownSampleBlock(nn.Module):
    """Spatial downsampling by 2 via a stride-2 3x3 convolution; channels unchanged."""
    def __init__(self,in_channel):
        super().__init__()
        self.conv=nn.Conv2d(in_channel,in_channel,kernel_size=3,stride=2,padding=1)

    def forward(self,img,t):
        del t  # unused; kept so every U-Net block shares the (img, t) signature
        return self.conv(img)

class UpSampleBlock(nn.Module):
    """Doubles spatial size with a 4x4, stride-2 transposed convolution; channels unchanged."""
    def __init__(self,in_channel):
        super().__init__()
        self.conv=nn.ConvTranspose2d(in_channel,in_channel,kernel_size=4,stride=2,padding=1)

    def forward(self,img,t):
        del t  # unused; kept so every U-Net block shares the (img, t) signature
        return self.conv(img)

class Unet(nn.Module):
    """Noise-conditional U-Net used as the score network.

    Channel widths grow multiplicatively: level ``i`` uses ``coe[i]`` times the
    previous level's width (e.g. in_channel=64, layers=3 -> 64, 128, 256).
    Every down-path output, plus the stem output, is kept as a skip connection
    and concatenated back, in reverse order, before each UpBlock.

    Args:
        layers: number of resolution levels, 1..4 (limited by ``coe``).
        img_ch: channels of the input / output image.
        in_channel: width after the stem convolution.
        blocks: residual blocks per level on the down path (the up path uses
            ``blocks + 1`` so that every skip connection is consumed).
        head_num: attention heads (only the middle block uses attention).
    """
    def __init__(self,layers,img_ch=1,in_channel=64,blocks=2,head_num=8):
        super().__init__()
        coe=(1,2,2,4)
        if not 1<=layers<=len(coe):
            raise ValueError(f"layers must be in [1, {len(coe)}], got {layers}")
        self.conv1=nn.Conv2d(img_ch,in_channel,kernel_size=(3,3),padding=(1,1))
        n_channel=in_channel
        time_ch=n_channel*4
        self.time=TimeEmbedding(time_ch)

        down=[]
        for i in range(layers):
            out_ch=coe[i]*in_channel
            for _ in range(blocks):
                down.append(DownBlock(in_channel,out_ch,time_ch,head_num))
                in_channel=out_ch
            if i<layers-1:
                down.append(DownSampleBlock(out_ch))
        self.down=nn.ModuleList(down)
        self.middle=MiddleBlock(out_ch,out_ch,time_ch,head_num)

        up=[]
        in_channel=out_ch
        for i in range(layers):
            out_ch=in_channel
            for _ in range(blocks):
                up.append(UpBlock(in_channel,out_ch,time_ch,head_num))
            # The last block of a level undoes that level's channel multiplier.
            out_ch=in_channel//coe[layers-1-i]
            up.append(UpBlock(in_channel,out_ch,time_ch,head_num))
            in_channel=out_ch
            if i<layers-1:
                up.append(UpSampleBlock(in_channel))
        self.up=nn.ModuleList(up)
        # After undoing every multiplier the width is back to n_channel.
        self.ins=nn.GroupNorm(8, n_channel)
        self.act=Swish()
        self.conv2=nn.Conv2d(in_channel,img_ch,kernel_size=(3,3),padding=(1,1))

    def forward(self,x,t):
        """Return the predicted score for ``x`` at noise level ``t``.

        ``t`` must hold one scalar per sample, e.g. shape (bs,) or the
        (bs, 1, 1, 1) sigma used for broadcasting against images; it is
        flattened before the embedding.
        """
        out=self.conv1(x)
        t=self.time(t.reshape(-1))
        skips=[out]
        for down in self.down:
            out=down(out,t)
            skips.append(out)
        out=self.middle(out,t)
        for up in self.up:
            if not isinstance(up,UpSampleBlock):
                out=torch.cat((out,skips.pop()),dim=1)
            out=up(out,t)
        return self.conv2(self.act(self.ins(out)))

In [None]:
class LinearWarmupStepLRScheduler:
    """Learning-rate schedule: linear warmup during epoch 0, step decay afterwards.

    In epoch 0 the LR rises linearly from ``warmup_start_lr`` to ``init_lr`` over
    ``warmup_steps`` iterations and then stays at ``init_lr``. From epoch 1 on it
    is ``max(min_lr, init_lr * decay_rate ** epoch)``. ``epochs`` is stored but
    not used by the schedule. Every param group receives the same LR.
    """

    def __init__(self,optimizer,epochs,min_lr,init_lr,decay_rate,warmup_start_lr,warmup_steps):
        self.optimizer=optimizer
        self.epochs=epochs
        self.min_lr=min_lr
        self.init_lr=init_lr
        self.decay_rate=decay_rate
        self.warmup_start_lr=warmup_start_lr
        self.warmup_steps=warmup_steps

    def step(self,cur_epoch,cur_step):
        """Set the LR for iteration ``cur_step`` of epoch ``cur_epoch``."""
        if cur_epoch!=0:
            self.step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )
            return
        self.warmup_lr_schedule(
            optimizer=self.optimizer,
            step=cur_step,
            max_step=self.warmup_steps,
            init_lr=self.warmup_start_lr,
            max_lr=self.init_lr,
        )

    @staticmethod
    def _assign_lr(optimizer,lr):
        """Write ``lr`` into every param group of ``optimizer``."""
        for group in optimizer.param_groups:
            group["lr"]=lr

    def warmup_lr_schedule(self,optimizer, step, max_step, init_lr, max_lr):
        """Warmup the learning rate (linear in ``step``, capped at ``max_lr``)."""
        progress=(max_lr - init_lr) * step / max(max_step, 1)
        self._assign_lr(optimizer,min(max_lr, init_lr + progress))

    def step_lr_schedule(self,epoch,optimizer,init_lr,min_lr,decay_rate):
        """Exponential per-epoch decay, floored at ``min_lr``."""
        decayed=init_lr*(decay_rate**epoch)
        self._assign_lr(optimizer,max(min_lr,decayed))

In [None]:
from omegaconf import OmegaConf
import logging
import numpy as np
import torch.nn.functional as F
from scipy import linalg
from torchvision import datasets, transforms
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Jupyter has already configured the root logger once, so force=True is needed
# to replace its handlers instead of basicConfig() silently doing nothing.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    force=True,
)
from torch.utils.data import DataLoader,Dataset

def _linear_ramp(n,like):
    """Return [0, 1/n, ..., (n-1)/n] on ``like``'s device (floating dtype)."""
    dtype=like.dtype if like.is_floating_point() else torch.float32
    return torch.arange(n,device=like.device,dtype=dtype)/n


def _pad_with_fade(img,max_h,max_w,mode):
    """Centre-pad one (C, H, W) image and fade the borders linearly to 0."""
    _, h, w = img.shape
    # Pad up to the next multiple of max_h / max_w (exactly max_* when smaller).
    pad_h = (max_h - h % max_h) % max_h
    pad_w = (max_w - w % max_w) % max_w
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left
    padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode=mode)
    height, width = padded.shape[-2:]

    # Horizontal ramps: 0 at the outer edge, rising towards the adjacent image
    # column. Applied to every row; the top/bottom pad rows are overwritten below.
    if pad_left > 0:
        ramp = _linear_ramp(pad_left, padded)
        padded[:, :, :pad_left] = ramp * padded[:, :, pad_left:pad_left + 1]
    if pad_right > 0:
        edge = width - pad_right
        ramp = _linear_ramp(pad_right, padded).flip(0)
        padded[:, :, edge:] = ramp * padded[:, :, edge - 1:edge]

    # Vertical ramps over full rows, so the corners get the product of the
    # horizontal and vertical ramps.
    if pad_top > 0:
        ramp = _linear_ramp(pad_top, padded)[:, None]
        padded[:, :pad_top, :] = ramp * padded[:, pad_top:pad_top + 1, :]
    if pad_bottom > 0:
        edge = height - pad_bottom
        ramp = _linear_ramp(pad_bottom, padded).flip(0)[:, None]
        padded[:, edge:, :] = ramp * padded[:, edge - 1:edge, :]
    return padded


def padding(imgs,max_h=1024,max_w=1024,mode='constant'):
    """Pad (C, H, W) images to a common size with faded borders and stack them.

    Each image is centred and padded up to the next multiple of ``max_h`` x
    ``max_w``. Every pad strip is then replaced by a linear ramp from 0 at the
    outer edge towards the nearest image row/column, so the padding blends
    smoothly into the image.

    Args:
        imgs: iterable of (C, H, W) tensors.
        max_h, max_w: target size multiples.
        mode: padding mode forwarded to ``torch.nn.functional.pad``.

    Returns:
        (N, C, H', W') tensor. Images that fail to pad are logged and skipped
        (best effort), so N may be smaller than the number of inputs.

    Raises:
        RuntimeError: if no image could be padded.
    """
    pad_img=[]
    for idx,img in enumerate(imgs):
        try:
            pad_img.append(_pad_with_fade(img,max_h,max_w,mode))
        except Exception as e:
            logging.warning("Error in padding img %d: %s", idx, e)
    if not pad_img:
        raise RuntimeError("padding: no image in the batch could be padded")
    return torch.stack(pad_img)

def custom_collate_fn(batch):
    """Collate ``(image, label, ...)`` samples into one padded image batch.

    Only element 0 of each sample (the image) is kept; labels are dropped.
    """
    return padding(list(map(lambda sample: sample[0], batch)))

class DataSet(Dataset):
    """Minimal map-style ``Dataset`` over any indexable sequence ``data``."""
    def __init__(self,data):
        self.data=data

    def __getitem__(self,index):
        return self.data[index]

    def __len__(self):
        # Previously `len(self,data)`, which raised NameError on any len() call.
        return len(self.data)

class Train:
    """Train a noise-conditional score network with denoising score matching.

    Args:
        conf_path: path to an OmegaConf config. Keys read under ``run``:
            ``weight_decay``, ``init_lr``, ``beta1``, ``beta2``, ``batch_size``,
            ``num_workers``, ``min_lr``, ``lr_decay_rate``, ``warmup_lr``,
            ``warmup_steps``.
        img_path: dataset root laid out for ``torchvision.datasets.ImageFolder``.
        model: score network, called as ``model(x, sigma)`` with ``sigma`` of
            shape (batch,).
    """

    def __init__(self,conf_path,img_path,model):
        self.model_config=OmegaConf.load(conf_path)
        self.model=model
        self.img_path=img_path
        self.optim=self.optimizer(model)

    def _device(self):
        """Return the device of the model's first parameter (CPU if it has none)."""
        param=next(self.model.parameters(),None)
        return param.device if param is not None else torch.device('cpu')

    def get_optimizer_params(self,model,weight_decay,lr_scale=1):
        """Split parameters into weight-decay and no-weight-decay groups.

        Biases and parameters with fewer than 2 dims (e.g. norm scales) get
        ``weight_decay=0``. ``lr_scale`` is stored on each group as an extra key;
        nothing in this class reads it.
        """
        p_wd,p_non_wd=[],[]
        for n,p in model.named_parameters():
            if p.ndim<2 or 'bias' in n:
                p_non_wd.append(p)
            else:
                p_wd.append(p)
        return [
            {"params":p_wd,"weight_decay":weight_decay,"lr_scale":lr_scale},
            {"params":p_non_wd,"weight_decay":0,"lr_scale":lr_scale},
        ]

    def optimizer(self,model):
        """Build AdamW over ``model`` using lr / betas / weight_decay from the config."""
        run=self.model_config.run
        optim_params=self.get_optimizer_params(model,run.weight_decay)
        num_parameters=sum(p.numel() for group in optim_params for p in group['params'])
        logging.info('number of trainable parameters: {}'.format(num_parameters))
        return torch.optim.AdamW(optim_params,lr=run.init_lr,betas=(run.beta1,run.beta2))

    def create_loader(self):
        """Return an iterator over shuffled, padded image batches (labels dropped).

        Images are converted with ToTensor and normalised with mean=std=0.5 per
        channel, then padded and stacked by ``custom_collate_fn``.
        """
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3,std=[0.5]*3)
        ])
        logging.info('start transform')
        dataset=datasets.ImageFolder(self.img_path,transform=transform)
        logging.info('finish transform')
        run=self.model_config.run
        # shuffle=True: ImageFolder orders samples by class; unshuffled
        # batches would feed one class at a time.
        loader=DataLoader(dataset,batch_size=run.batch_size,shuffle=True,
                          collate_fn=custom_collate_fn,num_workers=run.num_workers,
                          pin_memory=True)
        return iter(loader)

    def lr_scheduler(self,epochs):
        """Build the warmup + step-decay schedule for ``self.optim``.

        Previously a brand-new optimizer was created here, so the schedule
        never affected the optimizer that is actually stepped.
        """
        run=self.model_config.run
        return LinearWarmupStepLRScheduler(
            optimizer=self.optim,
            epochs=epochs,
            min_lr=run.min_lr,
            init_lr=run.init_lr,
            decay_rate=run.lr_decay_rate,
            warmup_start_lr=run.warmup_lr,
            warmup_steps=run.warmup_steps,
        )

    def train_epoch(self,epoch,iter_per_epoch,data_loader,lr_scheduler):
        """Run ``iter_per_epoch`` optimisation steps and return the per-step losses."""
        loss_list=[]
        device=self._device()
        for i in tqdm(range(iter_per_epoch),desc='iter'):
            samples=next(data_loader).to(device,non_blocking=True)
            lr_scheduler.step(epoch,i)
            loss=self.denoise_score(samples,1,0.01)
            loss.backward()
            self.optim.step()
            self.optim.zero_grad()
            loss_value=loss.item()
            loss_list.append(loss_value)  # record loss
            logging.info("cur_epoch:{0},cur_iter:{1},loss:{2}".format(epoch,i,loss_value))
        return loss_list

    def train(self,epochs):
        """Train for ``epochs`` epochs, show one sample per epoch, then plot the loss curve."""
        self.model.train()
        all_losses=[]
        # One scheduler for the whole run (it is stateless across calls).
        lr_sched=self.lr_scheduler(epochs)
        for epoch in tqdm(range(epochs),desc="epoch"):
            data_loader=self.create_loader()
            logging.info("Start training")
            epoch_loss=self.train_epoch(epoch,len(data_loader),data_loader,lr_sched)
            all_losses.extend(epoch_loss)
            x=self.anneal_langevin_dynamics((1,3,1024,1024))
            # Training data were normalised with mean=std=0.5 (range [-1, 1]);
            # map back to [0, 1] for imshow.
            img=((x[0].detach().cpu()+1)/2).clamp(0,1).permute(1,2,0).numpy()
            plt.imshow(img)
            plt.show()
        plt.figure(figsize=(10, 4))
        plt.plot(all_losses)
        plt.title("Loss Curve")
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.grid(True)
        plt.show()

    def denoise_score(self,x,init_sigma=1,end_sigma=0.01,L=10):
        """Denoising score-matching loss with lambda(sigma) = sigma**2 weighting.

        Each sample gets a noise level drawn uniformly from ``L`` log-spaced
        values between ``init_sigma`` and ``end_sigma``. The target score of
        the perturbed sample is ``-noise / sigma**2``. Each sample's squared
        error is weighted by its own sigma**2 before averaging over the batch.
        """
        sigma_seq=torch.logspace(math.log10(init_sigma),math.log10(end_sigma),L)
        bs=x.shape[0]
        idx=torch.randint(0,L,(bs,))
        # (bs, 1, ..., 1) so it broadcasts against x of any rank.
        sigma=sigma_seq[idx].view(-1,*([1]*(x.dim()-1))).to(x.device)
        noise=sigma*torch.randn_like(x)
        score=self.model(x+noise,sigma.view(-1))
        target=-noise/sigma**2
        per_sample=0.5*((score-target)**2).reshape(bs,-1).sum(dim=-1)
        return (per_sample*sigma.view(-1)**2).mean()

    def get_inception_score(self,dataloader,inception_model,splits,eps=1e-12):
        """Inception Score: exp(E_x KL(p(y|x) || p(y))), returned as (mean, std) over splits.

        ``inception_model(batch)`` must return class logits. ``eps`` guards log(0).
        """
        inception_model.eval()
        preds=[]
        with torch.no_grad():
            for batch in dataloader:
                logits=inception_model(batch)
                preds.append(F.softmax(logits,dim=-1).cpu().numpy())
        preds=np.concatenate(preds,axis=0)
        n=preds.shape[0]
        scores=[]
        for i in range(splits):
            part=preds[(i*n//splits):((i+1)*n//splits),:]
            p_y=np.mean(part,axis=0,keepdims=True)
            kl=part*(np.log(part+eps)-np.log(p_y+eps))
            scores.append(np.exp(np.mean(np.sum(kl,axis=1))))
        return np.mean(scores),np.std(scores)

    @staticmethod
    def _to_numpy(features):
        """Convert model features (tensor or array-like) to a float64 numpy array."""
        if isinstance(features,torch.Tensor):
            return features.detach().cpu().double().numpy()
        return np.asarray(features,dtype=np.float64)

    def get_fid(self,inception_model,real_img,generate_img,eps=1e-6):
        """Frechet distance between feature Gaussians of real and generated images.

        ``inception_model`` maps a batch to (N, D) features (tensor or array).
        If the matrix square root is not finite, ``eps`` is added to the
        covariance diagonals and it is recomputed.
        """
        with torch.no_grad():
            real_f=self._to_numpy(inception_model(real_img))
            generate_f=self._to_numpy(inception_model(generate_img))
        mu1,sigma1=np.mean(real_f,axis=0),np.cov(real_f,rowvar=False)
        mu2,sigma2=np.mean(generate_f,axis=0),np.cov(generate_f,rowvar=False)
        diff=mu1-mu2
        covmean=linalg.sqrtm(sigma1.dot(sigma2))
        if not np.isfinite(covmean).all():
            offset=np.eye(sigma1.shape[0])*eps
            covmean=linalg.sqrtm((sigma1+offset).dot(sigma2+offset))
        if np.iscomplexobj(covmean):
            covmean=covmean.real
        return diff.dot(diff)+np.trace(sigma1+sigma2-2*covmean)

    def anneal_langevin_dynamics(self,shape,low=0,high=1,init_sigma=1,end_sigma=0.01,L=10,T=1,eps=2e-5):
        """Sample with annealed Langevin dynamics (Song & Ermon 2019, Algorithm 1).

        Starts from U[low, high) noise and, for each of ``L`` log-spaced noise
        levels from ``init_sigma`` down to ``end_sigma``, runs ``T`` updates
        ``x <- x + alpha/2 * score + sqrt(alpha) * z`` with
        ``alpha = eps * (sigma / end_sigma)**2``. Runs without autograd, with the
        model in eval mode; the previous train/eval mode is restored afterwards.
        """
        device=self._device()
        was_training=self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                x_t=(high-low)*torch.rand(shape,device=device)+low
                sigma_seq=torch.logspace(math.log10(init_sigma),math.log10(end_sigma),L,device=device)
                for sigma in sigma_seq:
                    alpha=eps*(sigma/end_sigma)**2
                    # One noise level per sample, as the noise-conditional model expects.
                    sigma_batch=sigma.expand(shape[0])
                    for _ in range(T):
                        z=torch.randn_like(x_t)
                        score=self.model(x_t,sigma_batch)
                        x_t=x_t+0.5*alpha*score+torch.sqrt(alpha)*z
        finally:
            self.model.train(was_training)
        return x_t

    def get_max_hw(self):
        """Print and return the largest image height and width under ``img_path``.

        Scans ``img_path/<class>/<file>``; files PIL cannot open are skipped.
        Returns ``(max_h, max_w)``, which are ``-inf`` if no image was readable.
        """
        import os

        hm, wm = float('-inf'), float('-inf')
        root = self.img_path  # dataset root directory
        for class_name in os.listdir(root):
            class_path = os.path.join(root, class_name)
            if not os.path.isdir(class_path):
                continue
            for fname in os.listdir(class_path):
                fpath = os.path.join(class_path, fname)
                try:
                    with Image.open(fpath) as img:
                        w, h = img.size  # PIL reports (width, height)
                        hm = max(hm, h)
                        wm = max(wm, w)
                except Exception:
                    print(f"跳过无法识别的文件: {fpath}")

        print("最大高:", hm)
        print("最大宽:", wm)
        return hm, wm

    def save_checkpoint(self):
        """Placeholder; checkpointing is not implemented yet."""
        pass

In [None]:
# The data pipeline is RGB (3-value Normalize in create_loader) and sampling
# uses shape (1, 3, ...), so the network must take and return 3 channels;
# Unet's default img_ch=1 fails at the first convolution.
model=Unet(3,img_ch=3)
conf_path='/content/drive/MyDrive/config.yaml'
img_path='/content/drive/MyDrive/filter/'
train=Train(conf_path,img_path,model)
train.train(50)