In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import re
import zipfile
from collections import defaultdict

# ConvLSTMCell class definition
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell operating on 2-D feature maps.

    One convolution over the channel-wise concatenation of the input and the
    previous hidden state produces the pre-activations of all four gates.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        """Create the gate convolution.

        Args:
            input_channels: number of channels of the per-step input.
            hidden_channels: number of channels of the hidden and cell states.
            kernel_size: integer kernel size of the gate convolution; the
                padding is set to ``kernel_size // 2``.
        """
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=self.padding,
        )

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one step.

        Args:
            input_tensor: input for this step, concatenated with the hidden
                state along dim 1.
            cur_state: ``(h_cur, c_cur)`` tuple of the previous states.

        Returns:
            ``(h_next, c_next)``, the updated hidden and cell states.
        """
        prev_hidden, prev_cell = cur_state
        stacked = torch.cat([input_tensor, prev_hidden], dim=1)

        # NOTE(review): ELU is applied to every gate pre-activation before the
        # sigmoid/tanh below. With ELU's default lower bound of -1 the sigmoid
        # gates cannot drop below sigmoid(-1) ~= 0.27 - confirm this is intended.
        gates = F.elu(self.conv(stacked))

        in_pre, forget_pre, out_pre, cand_pre = gates.split(self.hidden_channels, dim=1)
        input_gate = torch.sigmoid(in_pre)
        forget_gate = torch.sigmoid(forget_pre)
        output_gate = torch.sigmoid(out_pre)
        candidate = torch.tanh(cand_pre)

        c_next = forget_gate * prev_cell + input_gate * candidate
        h_next = output_gate * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-filled ``(h, c)`` states on the same device as the parameters.

        Args:
            batch_size: size of the leading dimension of both states.
            image_size: ``(height, width)`` of the state feature maps.
        """
        height, width = image_size
        device = next(self.parameters()).device
        state_shape = (batch_size, self.hidden_channels, height, width)
        return (torch.zeros(state_shape, device=device),
                torch.zeros(state_shape, device=device))

# ConvEncoder class definition
class ConvEncoder(nn.Module):
    """Encode a frame sequence with one ConvLSTM cell.

    The cell uses ``hidden_channels // 2`` state channels. After every step the
    hidden state is batch-normalised and passed through 2-D dropout, and that
    processed state is both recorded and fed into the next step.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, dropout_prob=0.7):
        """
        Args:
            input_channels: channels of each input frame.
            hidden_channels: nominal width; the cell uses half of it.
            kernel_size: kernel size handed to the ConvLSTM cell.
            dropout_prob: probability passed to ``nn.Dropout2d``.
        """
        super().__init__()
        state_channels = hidden_channels // 2
        self.conv_lstm1 = ConvLSTMCell(input_channels, state_channels, kernel_size)

        self.batch_norm = nn.BatchNorm2d(state_channels)
        self.dropout = nn.Dropout2d(dropout_prob)

    def forward(self, input_tensor):
        """Run the cell over dim 1 of a 5-D input.

        Args:
            input_tensor: tensor unpacked as (batch, seq_len, channels, height, width).

        Returns:
            ``(layer_output, (h, c))`` where ``layer_output`` stacks the
            per-step hidden states along dim 1 and ``(h, c)`` are the final states.
        """
        batch_size, _, _, height, width = input_tensor.size()
        hidden, cell_state = self.conv_lstm1.init_hidden(batch_size, (height, width))

        per_step_hidden = []
        for frame in input_tensor.unbind(dim=1):
            hidden, cell_state = self.conv_lstm1(frame, (hidden, cell_state))
            hidden = self.dropout(self.batch_norm(hidden))
            per_step_hidden.append(hidden)

        return torch.stack(per_step_hidden, dim=1), (hidden, cell_state)

# ConvDecoder class definition
class ConvDecoder(nn.Module):
    """Attention-based ConvLSTM decoder that emits a single output frame.

    At every decoding step the flattened hidden state is scored against each
    flattened encoder step (dot product), the scores are soft-maxed over the
    encoder steps, and the resulting weighted sum of encoder states is the
    input of the ConvLSTM cell. Only the hidden state of the last step is
    projected to ``output_channels`` by a 1x1 convolution.
    """

    def __init__(self, hidden_channels, output_channels, kernel_size, dropout_prob=0.7):
        """
        Args:
            hidden_channels: nominal width; the cell uses ``hidden_channels // 2``.
            output_channels: channels of the emitted frame.
            kernel_size: kernel size handed to the ConvLSTM cell.
            dropout_prob: probability passed to ``nn.Dropout2d``.
        """
        super(ConvDecoder, self).__init__()
        self.conv_lstm1 = ConvLSTMCell(hidden_channels // 2, hidden_channels // 2, kernel_size)
        self.batch_norm = nn.BatchNorm2d(hidden_channels // 2)
        self.conv = nn.Conv2d(hidden_channels // 2, output_channels, kernel_size=1)
        self.dropout = nn.Dropout2d(dropout_prob)

    def forward(self, encoder_output, h, c, seq_len):
        """Decode ``seq_len`` steps and return the projection of the last one.

        Args:
            encoder_output: tensor unpacked as
                (batch, seq_len_enc, hidden_channels, height, width).
            h: initial hidden state; it must hold the same number of elements
                per sample as one encoder step for the dot product to work.
            c: initial cell state.
            seq_len: number of decoding steps, at least 1.

        Returns:
            Tensor of shape (batch, 1, output_channels, height, width).

        Raises:
            ValueError: if ``seq_len`` is smaller than 1.
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")

        batch_size, seq_len_enc, hidden_channels, height, width = encoder_output.size()
        # The flattened encoder states do not change between steps, so they are
        # built once instead of on every iteration.
        encoder_flat = encoder_output.reshape(batch_size, seq_len_enc, -1)

        for _ in range(seq_len):
            # Dot-product attention of the current hidden state over encoder steps.
            query = h.reshape(batch_size, -1, 1)
            attention_weights = torch.softmax(torch.bmm(encoder_flat, query).squeeze(2), dim=1)
            attention_applied = torch.bmm(attention_weights.unsqueeze(1), encoder_flat).squeeze(1)
            attention_applied = attention_applied.view(batch_size, hidden_channels, height, width)

            # LSTM cell update; the normalised, dropped-out state is carried forward.
            h, c = self.conv_lstm1(attention_applied, (h, c))
            h = self.batch_norm(h)
            h = self.dropout(h)

        # Only the final hidden state is projected; a length-1 sequence axis is
        # inserted at dim 1.
        output = self.conv(h)
        output = output.unsqueeze(1)

        return output

# Seq2SeqAutoencoder class definition
class Seq2SeqAutoencoder(nn.Module):
    """ConvEncoder followed by ConvDecoder; the decoder runs for a single step."""

    def __init__(self, input_channels, hidden_channels, output_channels, kernel_size):
        """
        Args:
            input_channels: channels of each input frame.
            hidden_channels: width forwarded to both encoder and decoder.
            output_channels: channels of the predicted frame.
            kernel_size: kernel size forwarded to both encoder and decoder.
        """
        super().__init__()
        self.encoder = ConvEncoder(input_channels, hidden_channels, kernel_size)
        self.decoder = ConvDecoder(hidden_channels, output_channels, kernel_size)

    def forward(self, input_tensor):
        """Encode the input sequence and decode one step from the final encoder state."""
        encoded_steps, final_state = self.encoder(input_tensor)
        hidden, cell_state = final_state
        return self.decoder(encoded_steps, hidden, cell_state, seq_len=1)

# Model hyperparameters
input_channels = 1
hidden_channels = 128
output_channels = 1
kernel_size = 3
sequence_length = 6

# Build the model from the hyperparameters above
model = Seq2SeqAutoencoder(
    input_channels=input_channels,
    hidden_channels=hidden_channels,
    output_channels=output_channels,
    kernel_size=kernel_size,
)

# Dataset class definition
class TyphoonDataset(Dataset):
    """Sliding-window dataset over an ordered list of image paths.

    Item ``idx`` covers ``sequence_length`` consecutive paths starting at
    ``idx``: the first ``sequence_length - 1`` frames are stacked as the input
    sequence and the last frame is the target. The window is taken over the
    flat path list as given, so the caller is responsible for passing paths
    that belong together (e.g. a single storm).
    """

    def __init__(self, image_paths, sequence_length=6, transform=None):
        """
        Args:
            image_paths: ordered sequence of image file paths.
            sequence_length: total frames per sample (inputs plus one target).
            transform: callable applied to every RGB PIL image; defaults to
                grayscale -> resize to 128x128 -> tensor.

        Raises:
            ValueError: if ``sequence_length`` is smaller than 2, which would
                leave no input frame next to the target.
        """
        if sequence_length < 2:
            raise ValueError(f"sequence_length must be >= 2, got {sequence_length}")
        self.image_paths = image_paths
        self.sequence_length = sequence_length
        self.transform = transform or transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        # Clamp at zero: a negative value makes len() raise ValueError when
        # there are fewer images than one full window.
        return max(0, len(self.image_paths) - self.sequence_length + 1)

    def __getitem__(self, idx):
        """Return ``(sequences, target)`` for the window starting at ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            # Without this check an out-of-range idx would silently slice a
            # shorter (or empty) window.
            raise IndexError(f"index {idx} out of range for {n_windows} windows")

        sequence_paths = self.image_paths[idx:idx + self.sequence_length]
        imgs = []
        for img_path in sequence_paths:
            # The context manager closes the underlying file once the image
            # has been converted.
            with Image.open(img_path) as img:
                imgs.append(self.transform(img.convert('RGB')))

        sequences = torch.stack(imgs[:-1])
        target = imgs[-1]  # [1, height, width] with the default transform
        return sequences, target

# Load and sort the image files
def load_images_sorted(directory):
    """Collect every ``.jpg`` file below ``directory`` and sort the paths.

    The directory tree is walked recursively; the resulting full paths are
    ordered by the key returned from ``extract_number``.
    """
    jpg_paths = [
        os.path.join(root, fname)
        for root, _, files in os.walk(directory)
        for fname in files
        if fname.endswith('.jpg')
    ]
    return sorted(jpg_paths, key=extract_number)

# Helper that extracts a number from a file name
def extract_number(filename):
    """Return the first run of digits in the file name as an int, or 0 if none.

    Only the final path component is inspected, so digits appearing in parent
    directory names do not influence the result when a full path is passed.
    """
    match = re.search(r'\d+', os.path.basename(filename))
    return int(match.group()) if match else 0

# Helper that extracts the storm ID and the time from a file name
def extract_storm_id_and_time(filepath):
    """Split the base name of ``filepath`` into ``(storm_id, time)``.

    The base name is split on ``'_'``: the first field is the storm ID and the
    second field, cut at its first ``'.'``, is the time. A base name without
    an underscore raises ``IndexError``.
    """
    name_fields = os.path.basename(filepath).split('_')
    storm_id = name_fields[0]
    time, _, _ = name_fields[1].partition('.')
    return storm_id, time

# Helper that extracts the info of every image
def extract_all_image_info(image_paths):
    """Return a list with the ``(storm_id, time)`` pair of each path, in order."""
    return [extract_storm_id_and_time(path) for path in image_paths]

# Unpack a zip archive
def extract_zip(zip_path, extract_to):
    """Extract every member of the archive at ``zip_path`` into ``extract_to``."""
    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=extract_to)

# Data locations
zip_path = '/root/data/storm.zip'
extract_to = '/root/data/stormoutput'
extract_zip(zip_path, extract_to)

all_images = load_images_sorted(extract_to)

# Group the images by storm ID (order inside each group follows all_images)
storm_groups = defaultdict(list)
for img in all_images:
    storm_id, _ = extract_storm_id_and_time(img)
    storm_groups[storm_id].append(img)

# Split every storm separately with the same ratio: the leading part of each
# storm goes to training, the trailing part to validation.
train_images = []
val_images = []
train_groups = []  # per-storm training paths, kept separate for windowing
val_groups = []    # per-storm validation paths, kept separate for windowing
split_ratio = 0.8

for storm_id, images in storm_groups.items():
    split_point = int(len(images) * split_ratio)
    train_groups.append(images[:split_point])
    val_groups.append(images[split_point:])
    train_images.extend(images[:split_point])
    val_images.extend(images[split_point:])

# Extract the (storm_id, time) info of every training / validation image
train_image_info = extract_all_image_info(train_images)
val_image_info = extract_all_image_info(val_images)


def _windowed_dataset(groups, window, split_name):
    """Build one TyphoonDataset per storm and concatenate them.

    Windowing each storm on its own guarantees that no sample mixes frames of
    two different storms, which happens when the window slides over the flat
    concatenation of all storms. Storms with fewer than ``window`` frames
    cannot form a sample and are skipped.
    """
    per_storm = [
        TyphoonDataset(paths, sequence_length=window)
        for paths in groups
        if len(paths) >= window
    ]
    if not per_storm:
        raise ValueError(
            f"No storm in the {split_name} split has at least {window} frames"
        )
    return torch.utils.data.ConcatDataset(per_storm)


# Create the datasets and data loaders.
# Each sample spans `sequence_length` (6) consecutive frames of one storm:
# 5 input frames followed by 1 target frame.
train_dataset = _windowed_dataset(train_groups, sequence_length, 'train')
val_dataset = _windowed_dataset(val_groups, sequence_length, 'validation')
# NOTE(review): the training loader does not shuffle, so every batch consists
# of consecutive, heavily overlapping windows - confirm this is intended.
train_loader = DataLoader(train_dataset, batch_size=6, shuffle=False, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=6, shuffle=False, drop_last=True)

# Smoke-test the model on a single batch; no gradients are needed here
with torch.no_grad():
    for sequences, target in train_loader:
        output = model(sequences)
        print("Output shape:", output.shape)  # check the output shape
        break


Output shape: torch.Size([6, 1, 1, 128, 128])


In [None]:
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

# Initialise the model, loss function and optimizer
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2SeqAutoencoder(input_channels=1, hidden_channels=128, output_channels=1, kernel_size=3)
model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.05)

train_losses = []
val_losses = []

num_epochs = 10  # number of epochs to train

for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0
    for sequences, targets in train_loader:
        sequences = sequences.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(sequences)
        # The model emits [batch, 1, channels, H, W] while the collated targets
        # are [batch, channels, H, W]. Without aligning the shapes MSELoss
        # broadcasts them against each other and compares every prediction
        # with every target of the batch. reshape raises if the element
        # counts differ, so a real mismatch cannot pass silently.
        targets = targets.reshape(outputs.shape)
        train_loss = criterion(outputs, targets)
        train_loss.backward()
        optimizer.step()
        total_train_loss += train_loss.item()
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    scheduler.step()  # Step the scheduler

    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for sequences, targets in val_loader:
            sequences = sequences.to(device)
            targets = targets.to(device)
            outputs = model(sequences)
            targets = targets.reshape(outputs.shape)  # same alignment as in training
            val_loss = criterion(outputs, targets)
            total_val_loss += val_loss.item()
    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}')

# Save the loss curve figure
loss_curve_path = '/root/fangan4/loss_curve.png'
# Create the output directory first so a missing folder does not abort the
# cell after the whole training run has finished.
os.makedirs(os.path.dirname(loss_curve_path), exist_ok=True)
plt.figure()
plt.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
plt.plot(range(1, num_epochs + 1), val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.savefig(loss_curve_path)
plt.show()


In [None]:
# Save the model and optimizer state dicts into a single checkpoint file
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, '/root/fangan4/model4.pth')