In [27]:
import torch
from torch import nn
from torch.nn import functional as F

class Patches(nn.Module):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: side length (in pixels) of each square patch; also used as the stride.

    Input is indexed as ``(batch, channels, height, width)`` (see the shape unpacking in
    ``forward``). Output has shape ``(batch, num_patches, channels * patch_size**2)``.
    Patches are ordered row-major over the patch grid, and each flattened patch is laid
    out channel-first, then row, then column.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, images):
        n, c, _, _ = images.shape
        size = self.patch_size
        # unfold H then W -> (n, c, grid_rows, grid_cols, size, size)
        windows = images.unfold(2, size, size).unfold(3, size, size)
        # bring the grid axes next to the batch axis -> (n, grid_rows, grid_cols, c, size, size)
        windows = windows.permute(0, 2, 3, 1, 4, 5)
        # reshape copies the non-contiguous permuted tensor into one row per patch
        return windows.reshape(n, -1, c * size * size)


class PatchEncoder(nn.Module):
    """Linearly project flattened patches and add a learned positional embedding.

    Args:
        num_patches: number of patches per image; one learned position vector each.
        projection_dim: size of the output token embedding.
        patch_dim: size of each flattened input patch. The default 768 matches
            3 channels * 16 * 16, i.e. ``Patches(16)`` on RGB input.

    ``forward`` expects ``patches`` shaped ``(batch, num_patches, patch_dim)`` and returns
    ``(batch, num_patches, projection_dim)``.
    """

    def __init__(self, num_patches, projection_dim, patch_dim=768):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = nn.Linear(patch_dim, projection_dim)
        self.position_embedding = nn.Embedding(num_patches, projection_dim)

    def forward(self, patches):
        # Build the position indices on the same device as the input instead of relying
        # on a notebook-global `device` variable (which is defined in a later cell).
        positions = torch.arange(self.num_patches, device=patches.device)
        projected = self.projection(patches)
        # (num_patches, projection_dim) broadcasts over the batch dimension.
        return projected + self.position_embedding(positions)


class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: multi-head self-attention followed by an FFN.

    Each sub-layer output goes through dropout, is added to its input (residual), and
    is then layer-normalised.

    Args:
        embed_dim: token feature size (last dimension of the input).
        num_heads: number of attention heads (nn.MultiheadAttention requires it to
            divide ``embed_dim``).
        ff_dim: hidden width of the position-wise feed-forward network.
        rate: dropout probability after attention and after the FFN.
        batch_first: if True (default), inputs are ``(batch, seq, embed_dim)``. That is the
            layout PatchEncoder produces in this notebook. With the PyTorch default
            (``batch_first=False``) attention would run across the batch axis and mix
            different images together.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, batch_first=True):
        super(TransformerBlock, self).__init__()
        self.att = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, inputs):
        # Attention weights are not used, so skip computing/averaging them.
        attn_output, _ = self.att(inputs, inputs, inputs, need_weights=False)
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.dropout2(self.ffn(out1))
        return self.layernorm2(out1 + ffn_output)


class ResidualBlock(nn.Module):
    """Two-convolution residual block: conv-BN-ReLU-conv-BN, add shortcut, then ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        downsample: if True, the first conv and the shortcut use stride 2, which roughly
            halves height and width.

    When the shortcut cannot be added as-is (stride 2, or ``in_channels != out_channels``),
    it is projected by a 1x1 convolution stored as ``downsample_conv``.
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super(ResidualBlock, self).__init__()
        self.downsample = downsample
        stride = 2 if downsample else 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Without a projection, a channel-changing block fails at the residual add.
        self.use_projection = downsample or in_channels != out_channels
        if self.use_projection:
            self.downsample_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        identity = self.downsample_conv(x) if self.use_projection else x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + identity)


class Generator(nn.Module):
    """Patch-transformer encoder followed by a transposed-convolution decoder.

    Pipeline: ``Patches`` -> ``PatchEncoder`` -> stack of ``TransformerBlock`` -> reshape to
    a ``(batch, 1024, 4, 4)`` feature map -> ``deconv_blocks``. The decoder produces a
    3-channel 256x256 map through a final Tanh. Spatial sizes: 4 -> 8 -> 16 -> 32 -> 128
    -> 128 -> 256.

    Args:
        patch_size: side length of each image patch.
        num_patches: number of patches per image.
        projection_dim: token embedding size used by the encoder and transformer blocks.
        num_heads: attention heads per transformer block.
        ff_dim: hidden width of each transformer block's FFN.
        num_transformer_blocks: number of stacked transformer blocks (default 4).

    Constraints (TODO confirm for non-default configs):
        * ``PatchEncoder`` projects 768-dim patches, i.e. RGB input with ``patch_size=16``.
        * ``num_patches * projection_dim`` must equal ``1024 * 4 * 4`` (e.g. 256 * 64) for
          the reshape in ``forward``.

    NOTE(review): the Tanh output lies in (-1, 1), while targets built with
    ``transforms.ToTensor()`` lie in [0, 1]. Consider rescaling one side.
    """

    def __init__(self, patch_size, num_patches, projection_dim, num_heads, ff_dim,
                 num_transformer_blocks=4):
        super(Generator, self).__init__()
        self.patches = Patches(patch_size)
        self.patch_encoder = PatchEncoder(num_patches, projection_dim)
        # The blocks must operate at the encoder's width (was hard-coded to 64).
        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(projection_dim, num_heads, ff_dim)
              for _ in range(num_transformer_blocks)]
        )
        self.deconv_blocks = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1),  # 4 -> 8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            ResidualBlock(512, 512),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),  # 8 -> 16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            ResidualBlock(256, 256),
            nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2, padding=2, output_padding=1),  # 16 -> 32
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            ResidualBlock(64, 64),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=4, padding=2, output_padding=3),  # 32 -> 128
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            ResidualBlock(32, 32),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(),
            ResidualBlock(3, 3),
            nn.ConvTranspose2d(3, 3, (2, 2), (2, 2)),  # 128 -> 256
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.patches(x)
        x = self.patch_encoder(x)
        x = self.transformer_blocks(x)
        # (batch, num_patches, projection_dim) -> (batch, 1024, 4, 4); see class docstring.
        x = x.reshape(x.size(0), 1024, 4, 4)
        return self.deconv_blocks(x)
#a = torch.rand((1,3,256,256))
#generator = Generator(16,256,64,2,32)
#generator(a).shape

In [28]:
# 读取数据集
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
from matplotlib import pyplot as plt

class ImageDataset(Dataset):
    """Paired-image dataset: item ``i`` is the same-named file from two directories.

    Args:
        root_dir1: directory of the first image of each pair (returned first).
        root_dir2: directory of the second image of each pair (returned second).
        transform: optional callable applied to each PIL image of the pair.
        max_images: keep at most this many files (after sorting by name); ``None``
            keeps all of them.

    Both directories must contain the same file names (after sorting and truncation),
    otherwise construction fails with an AssertionError.
    """

    def __init__(self, root_dir1, root_dir2, transform=None, max_images=10000):
        self.root_dir1 = root_dir1
        self.root_dir2 = root_dir2
        self.transform = transform
        self.image1_paths = sorted(os.listdir(root_dir1))[:max_images]
        self.image2_paths = sorted(os.listdir(root_dir2))[:max_images]
        assert self.image1_paths == self.image2_paths, (
            "root_dir1 and root_dir2 must contain identically named files"
        )

    def __len__(self):
        return len(self.image1_paths)

    @staticmethod
    def _load_rgb(path):
        # Load the image as RGB and close the underlying file. convert() returns a
        # fully-loaded copy, so the result stays valid after the file is closed.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        image1 = self._load_rgb(os.path.join(self.root_dir1, self.image1_paths[idx]))
        image2 = self._load_rgb(os.path.join(self.root_dir2, self.image2_paths[idx]))
        if self.transform:
            image1 = self.transform(image1)
            image2 = self.transform(image2)
        return image1, image2

# Example usage
# 1. Build the transform: resize every image to 256x256 (the Generator's input size)
#    and convert the PIL image to a tensor with transforms.ToTensor().
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# 2. Create the dataset. Set the STL10_ROOT environment variable to point at the data
#    instead of editing this machine-specific default path.
DATA_ROOT = os.environ.get("STL10_ROOT", "D:\\NLOS\\DataSet\\STL10")
root_dir1 = os.path.join(DATA_ROOT, "GT_stl10_allimages")
root_dir2 = os.path.join(DATA_ROOT, "pro")

dataset = ImageDataset(root_dir1, root_dir2, transform=transform)

# 3. Access an image pair by index (uncomment to inspect a sample)
#image1, image2 = dataset[0]  # first image pair
#print(image2.shape)

#plt.imshow(image1.permute(1,2,0))
#plt.show()

In [29]:
#准备训练
import torch
import torch.nn as nn
import torch.optim as optim


# Seed torch's RNGs so weight init and shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model = Generator(16, 256, 64, 2, 32).to(device)
# Shuffle so each epoch visits the image pairs in a different order.
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
print(len(train_dataloader))

# Training hyper-parameters
num_epochs = 100
learning_rate = 0.001

# Loss and optimizer. The model outputs an image batch (B, 3, 256, 256) that is
# compared with a target image batch of the same shape, which is pixel-wise regression.
# CrossEntropyLoss would instead treat the 3 colour channels as class logits.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)



cuda
1250


In [30]:
# Training loop
for epoch in range(num_epochs):
    # Put the model in training mode (enables dropout / BatchNorm batch statistics).
    model.train()

    # Accumulate on-device to avoid a host sync on every step.
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in train_dataloader:
        # Move the batch to the training device (e.g. GPU).
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass and loss
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()
        num_batches += 1

    # Report the mean loss over the epoch, not just the last batch. This also avoids
    # a NameError on `loss` when the dataloader yields no batches.
    epoch_loss = float(running_loss) / max(num_batches, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')


KeyboardInterrupt: 