In [None]:
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import typing as t


import warnings
warnings.filterwarnings('ignore')

# Ячейка RNN

Перед вами стоит задача реализовать ячейку однослойной рекуррентной нейронной сети (RNN Cell).

На вход ячейка принимает вектор элемента входной последовательности $ x_i $ размерности $ d_x $ и вектор hidden state предыдущей ячейки $ h_{i-1} $ размерности $ d_h $, а на выходе отдает вектор элемента выходной последовательности $ y_i $ размерности $ d_y $ и собственный вектор hidden state $ h_i $ размерности $ d_h $.

Выходной Hidden State ячейки вычисляется по формуле $ h_i = \tanh(W_{xh} x_i +  W_{hh} h_{i-1} + b_{h}) $.

Выходной элемент вычисляется по формуле $ y_i = W_{hy}h_i + b_{y}$.

Реализуйте модуль `RNNCell`, принимающий $x_i$ `x`, $h_{i-1}$ `h_prev` и возвращающий кортеж из $(y_i, h_i)$.

Модуль `RNNCell` принимает в конструкторе натуральные числа - $d_x$ `input_dim`, $d_h$ `hidden_dim` и $d_y$ `output_dim`.

Не изменяйте имена и типы указанных в шаблоне атрибутов, чтобы проверяющая система смогла обработать решение.

In [None]:
class RNNCell(nn.Module):
    """Single-layer vanilla RNN cell.

    Computes, for one sequence element:

        h_i = tanh(W_xh @ x_i + W_hh @ h_{i-1} + b_h)
        y_i = W_hy @ h_i + b_y

    Args:
        input_dim: size of the input vector ``x_i``.
        hidden_dim: size of the hidden state ``h_i``.
        output_dim: size of the output vector ``y_i``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()

        # Weights are stored as (out_features, in_features), so each one acts
        # as a left-multiplication matrix on a column vector.
        self.W_xh = nn.Parameter(torch.randn((hidden_dim, input_dim)))
        self.W_hh = nn.Parameter(torch.randn((hidden_dim, hidden_dim)))
        self.b_h = nn.Parameter(torch.randn((hidden_dim)))

        self.W_hy = nn.Parameter(torch.randn((output_dim, hidden_dim)))
        self.b_y = nn.Parameter(torch.randn((output_dim)))

    def forward(self, x: torch.Tensor, h_prev: torch.Tensor) -> t.Tuple[torch.Tensor, torch.Tensor]:
        """Run one recurrence step.

        Args:
            x: input element whose last dimension is ``input_dim``.
            h_prev: previous hidden state whose last dimension is ``hidden_dim``.

        Returns:
            Tuple ``(y, h)``: the output element and the new hidden state.
        """
        # ``v @ W.T`` equals ``W @ v`` for a 1-D ``v`` and also works when the
        # inputs carry leading batch dimensions.
        h = torch.tanh(x @ self.W_xh.T + h_prev @ self.W_hh.T + self.b_h)
        y = h @ self.W_hy.T + self.b_y
        return y, h

In [None]:
# Check cell for RNNCell: overwrite the randomly initialised parameters with
# fixed values, run a single step and compare with hard-coded expected results.
with torch.no_grad():  # in-place copy_() into parameters must not be tracked by autograd
    # Input element (8 values) and previous hidden state (16 values).
    x = torch.Tensor([-0.2504,  0.9800,  0.5398,  1.3617, -1.5480,  0.8298, -0.7164, -0.1525])
    h_prev = torch.Tensor([
        -0.0677, -0.8649, -0.7603, -0.0229,  0.1861,  3.2758,  0.4451,  0.2717,
         1.4108,  0.0612, -0.2484, -1.5802, -1.1555,  0.9610, -2.1633,  0.0999
    ])
        
    cell = RNNCell(input_dim=8, hidden_dim=16, output_dim=8)
    # W_xh: 16 rows x 8 columns.
    cell.W_xh.copy_(torch.Tensor(
        [[ 0.0406,  0.2349,  0.0105,  0.2168,  0.1599,  0.0331, -0.0509, -0.0447],
        [ 0.1662,  0.0130, -0.0337, -0.1074, -0.1083,  0.1661,  0.0699,  0.0472],
        [-0.2137,  0.0695,  0.2388,  0.0276, -0.2023,  0.1200,  0.1877, -0.0928],
        [ 0.0446,  0.1321,  0.0169,  0.0440,  0.1799,  0.2268, -0.1644,  0.1853],
        [ 0.0688,  0.2237,  0.1891, -0.1934,  0.1648,  0.1595,  0.1382,  0.1909],
        [-0.0313, -0.1928,  0.0502, -0.0801,  0.1119,  0.2496,  0.0955, -0.1018],
        [ 0.1007,  0.1439, -0.1244, -0.0602, -0.0404, -0.0751,  0.0753, -0.2126],
        [ 0.1852, -0.0228,  0.0414,  0.1247, -0.0382, -0.2073, -0.0217,  0.1214],
        [ 0.1791,  0.1628,  0.1613,  0.0748, -0.1029,  0.2333,  0.0974, -0.1097],
        [ 0.2478,  0.1273, -0.0543,  0.2002, -0.2210,  0.1301, -0.0912,  0.2429],
        [ 0.0932,  0.0508, -0.0118, -0.0025,  0.1182,  0.0308, -0.0185,  0.0785],
        [ 0.2497,  0.2434, -0.2230,  0.2131, -0.0506, -0.1561, -0.0063,  0.0071],
        [-0.2488, -0.2414, -0.2098, -0.1004, -0.0247, -0.2486,  0.0511, -0.1239],
        [ 0.0170,  0.1625, -0.2013, -0.1176,  0.1106,  0.1020,  0.1188, -0.0986],
        [ 0.0347,  0.0508, -0.2134,  0.1188, -0.1166, -0.1257, -0.0018, -0.1294],
        [ 0.1914, -0.1379,  0.2048, -0.1244,  0.0215,  0.0065,  0.1909,  0.1165]]
        
    ))
    # W_hh: 16 rows x 16 columns (each row is wrapped over two source lines).
    cell.W_hh.copy_(torch.Tensor(
        [[ 0.0478, -0.1842,  0.0311, -0.1609,  0.1658,  0.1759,  0.0606,  0.1576,
          0.0005, -0.1109, -0.0180,  0.0627, -0.2300, -0.1832, -0.0912,  0.2232],
        [ 0.0773, -0.2239, -0.1193, -0.2469,  0.2075, -0.1265,  0.0925, -0.1849,
          0.2264, -0.1034, -0.2423,  0.1826,  0.2086, -0.0932,  0.0459, -0.0289],
        [ 0.0346,  0.0026,  0.2352, -0.1480,  0.1724, -0.1090,  0.2177,  0.1470,
         -0.1529, -0.1101,  0.0953, -0.1780, -0.2014,  0.1175,  0.0034, -0.2008],
        [-0.0272, -0.2022,  0.0973, -0.0868,  0.0937, -0.0744, -0.2322,  0.1686,
          0.2041,  0.2289, -0.1019, -0.1307,  0.1082,  0.1127,  0.0810, -0.0212],
        [-0.0625, -0.1929,  0.1032,  0.0717, -0.2095, -0.0833,  0.2060,  0.0393,
         -0.0017, -0.2424,  0.1011,  0.1084, -0.0507,  0.0380,  0.1606, -0.1064],
        [ 0.2354,  0.1944, -0.0348, -0.0791, -0.1073,  0.0490, -0.0145,  0.2091,
         -0.0509,  0.0794,  0.1676, -0.0799,  0.0610,  0.0923,  0.1328, -0.0943],
        [-0.1152, -0.1263, -0.0177, -0.1844,  0.2209, -0.1556, -0.0231, -0.2491,
         -0.1764,  0.0462,  0.0301, -0.0514, -0.0831,  0.2381, -0.0296, -0.1493],
        [ 0.0875,  0.0682, -0.2049,  0.2127,  0.1225, -0.0340,  0.2139, -0.1940,
         -0.1596,  0.2279,  0.0067, -0.2343,  0.0280, -0.0747, -0.0147,  0.1535],
        [-0.0293,  0.1301,  0.1968, -0.1572,  0.0800, -0.2406, -0.2052, -0.1352,
          0.0198, -0.1625, -0.2067,  0.0007, -0.2270, -0.2007, -0.0817,  0.0085],
        [-0.0042,  0.1470,  0.1297, -0.1006, -0.1358, -0.1872, -0.2042,  0.2296,
         -0.2233, -0.1409,  0.1878,  0.0304,  0.2365,  0.1029,  0.1432,  0.2008],
        [ 0.1693, -0.0513, -0.0203, -0.0693,  0.1079, -0.0635, -0.1428,  0.0494,
          0.0802,  0.1385, -0.0716,  0.0783, -0.0984, -0.1418, -0.0036,  0.0212],
        [ 0.2301,  0.1138, -0.1399,  0.1093, -0.1122, -0.2258, -0.2126,  0.2026,
          0.2073, -0.1753, -0.0526, -0.2077,  0.1696, -0.0842, -0.0802,  0.2134],
        [-0.1557, -0.0470, -0.1055, -0.0037, -0.1743, -0.0108,  0.2397,  0.1491,
          0.1103,  0.1330,  0.2384,  0.1541, -0.0777,  0.2000,  0.1809, -0.1571],
        [ 0.1807,  0.1401,  0.2319, -0.2321,  0.1184, -0.2419, -0.0857, -0.0775,
         -0.0775,  0.0141, -0.0578,  0.0138,  0.1883,  0.0045, -0.1924, -0.2058],
        [-0.1946,  0.0101,  0.1604, -0.1925,  0.0242, -0.1358,  0.0951, -0.0285,
         -0.2319,  0.1631, -0.0933,  0.1820, -0.0486, -0.1240,  0.0724,  0.1316],
        [-0.0019,  0.0008, -0.0875,  0.1283, -0.0355,  0.0270, -0.0059,  0.2113,
         -0.2407, -0.1975,  0.2467,  0.1493, -0.1066, -0.1537,  0.1096, -0.2061]]
    ))
    cell.b_h.copy_(torch.Tensor(
        [ 0.0386, -0.2340, -0.1360, -0.0949, -0.0941,  0.1709,  0.2353,  0.2209,
         0.1486, -0.2328, -0.0414,  0.2084, -0.0596,  0.0189,  0.1878, -0.0778]
    ))
    # W_hy: 8 rows x 16 columns.
    cell.W_hy.copy_(torch.Tensor(
        [[ 0.0638, -0.0497,  0.0741, -0.2310, -0.1238, -0.1589, -0.0043,  0.0872,
         -0.0911, -0.1578,  0.1806,  0.1879, -0.1788,  0.0789, -0.1583, -0.0673],
        [ 0.2296, -0.1817,  0.1387,  0.0339,  0.2427,  0.0423, -0.1245, -0.1397,
          0.1316, -0.2366, -0.0957,  0.1203,  0.1283,  0.0793, -0.0440, -0.1760],
        [ 0.0335, -0.0287,  0.1878,  0.0892, -0.0537, -0.0460,  0.2408,  0.0165,
          0.2168, -0.0702,  0.1840,  0.1339, -0.2065,  0.1920, -0.1744, -0.2482],
        [ 0.0797, -0.0493,  0.1970, -0.0941,  0.1985, -0.1668,  0.1411, -0.2062,
          0.2036, -0.1010,  0.1353, -0.0030,  0.2277, -0.2483, -0.0596, -0.1353],
        [-0.1743, -0.1762, -0.0354,  0.1196,  0.2114, -0.0111,  0.1846,  0.0568,
          0.1948, -0.0663, -0.1427, -0.1866, -0.0521, -0.1291,  0.2172,  0.2269],
        [-0.1905, -0.0812, -0.0107,  0.0155,  0.0824,  0.0845,  0.0869,  0.1338,
          0.2164,  0.1928, -0.1952, -0.2499, -0.2178, -0.0079, -0.1099, -0.1504],
        [ 0.1479,  0.1644, -0.2477,  0.1681, -0.0515,  0.1599,  0.1218,  0.1776,
          0.2301, -0.1107, -0.1476,  0.2119,  0.1169, -0.1896,  0.1747, -0.1166],
        [-0.0299,  0.1371,  0.0456, -0.2152,  0.0322, -0.2070,  0.1516, -0.2265,
          0.0596, -0.0044,  0.1469, -0.0675,  0.2118,  0.0634,  0.1256,  0.1608]]
    ))
    cell.b_y.copy_(torch.Tensor(
        [-0.0615, -0.2446, -0.0401,  0.1929, -0.1269, -0.1158, -0.0044,  0.1368]
    ))
    # Single forward step with the fixed parameters.
    y, h = cell(x, h_prev)
    
    # Hard-coded expected hidden state and output, compared with absolute tolerance 1e-4.
    h_true = torch.Tensor(
        [ 0.8863, -0.5561,  0.4051,  0.2471, -0.7160, -0.3247, -0.0295,  0.3529,
        -0.0904, -0.8400, -0.3505,  0.2327, -0.6186, -0.8665, -0.7326, -0.8436]
    )
    y_true = torch.Tensor([ 0.5030,  0.1724,  0.4603,  0.4330, -0.4631, -0.1040,  0.2561, -0.5234])
    assert torch.allclose(h, h_true, atol=1e-4)
    assert torch.allclose(y, y_true, atol=1e-4)

# RNN

Перед вами стоит задача реализовать модуль однослойной рекуррентной нейронной сети (RNN).

На вход модуль принимает матрицу входной последовательности $X$ размерности $(n, d_{x})$ и вектор начального hidden state $h_0$ размерности $d_h$, а на выходе отдает матрицу выходной последовательности $Y$ размерности $(n, d_{y})$ и вектор конечного hidden state $h_n$ размерности $d_h$.

Реализуйте модуль `RNN`, принимающий $X$ `x`, $h_{0}$ `h_initial` и возвращающий кортеж из $(Y, h_{n})$. 

Переиспользуйте уже готовый модуль `RNNCell` из предыдущей задачи.

Модуль `RNN` также принимает в конструкторе натуральные числа - $d_x$ `input_dim`, $d_h$ `hidden_dim` и $d_y$ `output_dim`.

Не изменяйте имена и типы указанных в шаблоне атрибутов, чтобы проверяющая система смогла обработать решение.

In [None]:
class RNN(nn.Module):
    """Single-layer RNN that unrolls one shared ``RNNCell`` over a sequence.

    Args:
        input_dim: size of each input element.
        hidden_dim: size of the hidden state.
        output_dim: size of each output element.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()

        self.cell = RNNCell(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)

    def forward(self, x: torch.Tensor, h_initial: torch.Tensor) -> t.Tuple[torch.Tensor, torch.Tensor]:
        """Process the sequence element by element, threading the hidden state.

        Args:
            x: input sequence; the first dimension indexes sequence elements.
            h_initial: hidden state fed to the first step.

        Returns:
            Tuple ``(y, h_last)``: the outputs stacked along the first dimension
            and the hidden state after the last element. For an empty sequence
            ``y`` has shape ``(0, output_dim)`` and ``h_last`` is ``h_initial``.
        """
        h = h_initial
        outputs = []
        for x_i in x:  # iterates over the first (sequence) dimension
            y_i, h = self.cell(x_i, h)
            outputs.append(y_i)

        if not outputs:
            # torch.stack() rejects an empty list, so build the empty result by hand.
            b_y = self.cell.b_y
            return b_y.new_zeros((0, b_y.shape[0])), h

        return torch.stack(outputs), h

In [None]:
@torch.no_grad()  # in-place copy_() into leaf parameters is only legal with autograd disabled
def setup_rnn():
    """Build an ``RNN`` (input 8, hidden 16, output 8) with fixed parameters.

    The randomly initialised cell parameters are overwritten in place with
    hard-coded values so that the forward pass is reproducible. Gradient
    tracking is disabled by the decorator, so the helper can be called from
    any context, not only from inside a ``torch.no_grad()`` block.

    Returns:
        The configured ``RNN`` instance.
    """
    rnn = RNN(input_dim=8, hidden_dim=16, output_dim=8)
    # W_xh: 16 rows x 8 columns.
    rnn.cell.W_xh.copy_(torch.Tensor(
        [[ 0.0406,  0.2349,  0.0105,  0.2168,  0.1599,  0.0331, -0.0509, -0.0447],
        [ 0.1662,  0.0130, -0.0337, -0.1074, -0.1083,  0.1661,  0.0699,  0.0472],
        [-0.2137,  0.0695,  0.2388,  0.0276, -0.2023,  0.1200,  0.1877, -0.0928],
        [ 0.0446,  0.1321,  0.0169,  0.0440,  0.1799,  0.2268, -0.1644,  0.1853],
        [ 0.0688,  0.2237,  0.1891, -0.1934,  0.1648,  0.1595,  0.1382,  0.1909],
        [-0.0313, -0.1928,  0.0502, -0.0801,  0.1119,  0.2496,  0.0955, -0.1018],
        [ 0.1007,  0.1439, -0.1244, -0.0602, -0.0404, -0.0751,  0.0753, -0.2126],
        [ 0.1852, -0.0228,  0.0414,  0.1247, -0.0382, -0.2073, -0.0217,  0.1214],
        [ 0.1791,  0.1628,  0.1613,  0.0748, -0.1029,  0.2333,  0.0974, -0.1097],
        [ 0.2478,  0.1273, -0.0543,  0.2002, -0.2210,  0.1301, -0.0912,  0.2429],
        [ 0.0932,  0.0508, -0.0118, -0.0025,  0.1182,  0.0308, -0.0185,  0.0785],
        [ 0.2497,  0.2434, -0.2230,  0.2131, -0.0506, -0.1561, -0.0063,  0.0071],
        [-0.2488, -0.2414, -0.2098, -0.1004, -0.0247, -0.2486,  0.0511, -0.1239],
        [ 0.0170,  0.1625, -0.2013, -0.1176,  0.1106,  0.1020,  0.1188, -0.0986],
        [ 0.0347,  0.0508, -0.2134,  0.1188, -0.1166, -0.1257, -0.0018, -0.1294],
        [ 0.1914, -0.1379,  0.2048, -0.1244,  0.0215,  0.0065,  0.1909,  0.1165]]
    ))
    # W_hh: 16 rows x 16 columns (each row is wrapped over two source lines).
    rnn.cell.W_hh.copy_(torch.Tensor(
        [[ 0.0478, -0.1842,  0.0311, -0.1609,  0.1658,  0.1759,  0.0606,  0.1576,
          0.0005, -0.1109, -0.0180,  0.0627, -0.2300, -0.1832, -0.0912,  0.2232],
        [ 0.0773, -0.2239, -0.1193, -0.2469,  0.2075, -0.1265,  0.0925, -0.1849,
          0.2264, -0.1034, -0.2423,  0.1826,  0.2086, -0.0932,  0.0459, -0.0289],
        [ 0.0346,  0.0026,  0.2352, -0.1480,  0.1724, -0.1090,  0.2177,  0.1470,
         -0.1529, -0.1101,  0.0953, -0.1780, -0.2014,  0.1175,  0.0034, -0.2008],
        [-0.0272, -0.2022,  0.0973, -0.0868,  0.0937, -0.0744, -0.2322,  0.1686,
          0.2041,  0.2289, -0.1019, -0.1307,  0.1082,  0.1127,  0.0810, -0.0212],
        [-0.0625, -0.1929,  0.1032,  0.0717, -0.2095, -0.0833,  0.2060,  0.0393,
         -0.0017, -0.2424,  0.1011,  0.1084, -0.0507,  0.0380,  0.1606, -0.1064],
        [ 0.2354,  0.1944, -0.0348, -0.0791, -0.1073,  0.0490, -0.0145,  0.2091,
         -0.0509,  0.0794,  0.1676, -0.0799,  0.0610,  0.0923,  0.1328, -0.0943],
        [-0.1152, -0.1263, -0.0177, -0.1844,  0.2209, -0.1556, -0.0231, -0.2491,
         -0.1764,  0.0462,  0.0301, -0.0514, -0.0831,  0.2381, -0.0296, -0.1493],
        [ 0.0875,  0.0682, -0.2049,  0.2127,  0.1225, -0.0340,  0.2139, -0.1940,
         -0.1596,  0.2279,  0.0067, -0.2343,  0.0280, -0.0747, -0.0147,  0.1535],
        [-0.0293,  0.1301,  0.1968, -0.1572,  0.0800, -0.2406, -0.2052, -0.1352,
          0.0198, -0.1625, -0.2067,  0.0007, -0.2270, -0.2007, -0.0817,  0.0085],
        [-0.0042,  0.1470,  0.1297, -0.1006, -0.1358, -0.1872, -0.2042,  0.2296,
         -0.2233, -0.1409,  0.1878,  0.0304,  0.2365,  0.1029,  0.1432,  0.2008],
        [ 0.1693, -0.0513, -0.0203, -0.0693,  0.1079, -0.0635, -0.1428,  0.0494,
          0.0802,  0.1385, -0.0716,  0.0783, -0.0984, -0.1418, -0.0036,  0.0212],
        [ 0.2301,  0.1138, -0.1399,  0.1093, -0.1122, -0.2258, -0.2126,  0.2026,
          0.2073, -0.1753, -0.0526, -0.2077,  0.1696, -0.0842, -0.0802,  0.2134],
        [-0.1557, -0.0470, -0.1055, -0.0037, -0.1743, -0.0108,  0.2397,  0.1491,
          0.1103,  0.1330,  0.2384,  0.1541, -0.0777,  0.2000,  0.1809, -0.1571],
        [ 0.1807,  0.1401,  0.2319, -0.2321,  0.1184, -0.2419, -0.0857, -0.0775,
         -0.0775,  0.0141, -0.0578,  0.0138,  0.1883,  0.0045, -0.1924, -0.2058],
        [-0.1946,  0.0101,  0.1604, -0.1925,  0.0242, -0.1358,  0.0951, -0.0285,
         -0.2319,  0.1631, -0.0933,  0.1820, -0.0486, -0.1240,  0.0724,  0.1316],
        [-0.0019,  0.0008, -0.0875,  0.1283, -0.0355,  0.0270, -0.0059,  0.2113,
         -0.2407, -0.1975,  0.2467,  0.1493, -0.1066, -0.1537,  0.1096, -0.2061]]
    ))
    rnn.cell.b_h.copy_(torch.Tensor(
        [ 0.0386, -0.2340, -0.1360, -0.0949, -0.0941,  0.1709,  0.2353,  0.2209,
         0.1486, -0.2328, -0.0414,  0.2084, -0.0596,  0.0189,  0.1878, -0.0778]
    ))
    # W_hy: 8 rows x 16 columns.
    rnn.cell.W_hy.copy_(torch.Tensor(
        [[ 0.0638, -0.0497,  0.0741, -0.2310, -0.1238, -0.1589, -0.0043,  0.0872,
         -0.0911, -0.1578,  0.1806,  0.1879, -0.1788,  0.0789, -0.1583, -0.0673],
        [ 0.2296, -0.1817,  0.1387,  0.0339,  0.2427,  0.0423, -0.1245, -0.1397,
          0.1316, -0.2366, -0.0957,  0.1203,  0.1283,  0.0793, -0.0440, -0.1760],
        [ 0.0335, -0.0287,  0.1878,  0.0892, -0.0537, -0.0460,  0.2408,  0.0165,
          0.2168, -0.0702,  0.1840,  0.1339, -0.2065,  0.1920, -0.1744, -0.2482],
        [ 0.0797, -0.0493,  0.1970, -0.0941,  0.1985, -0.1668,  0.1411, -0.2062,
          0.2036, -0.1010,  0.1353, -0.0030,  0.2277, -0.2483, -0.0596, -0.1353],
        [-0.1743, -0.1762, -0.0354,  0.1196,  0.2114, -0.0111,  0.1846,  0.0568,
          0.1948, -0.0663, -0.1427, -0.1866, -0.0521, -0.1291,  0.2172,  0.2269],
        [-0.1905, -0.0812, -0.0107,  0.0155,  0.0824,  0.0845,  0.0869,  0.1338,
          0.2164,  0.1928, -0.1952, -0.2499, -0.2178, -0.0079, -0.1099, -0.1504],
        [ 0.1479,  0.1644, -0.2477,  0.1681, -0.0515,  0.1599,  0.1218,  0.1776,
          0.2301, -0.1107, -0.1476,  0.2119,  0.1169, -0.1896,  0.1747, -0.1166],
        [-0.0299,  0.1371,  0.0456, -0.2152,  0.0322, -0.2070,  0.1516, -0.2265,
          0.0596, -0.0044,  0.1469, -0.0675,  0.2118,  0.0634,  0.1256,  0.1608]]
    ))
    rnn.cell.b_y.copy_(torch.Tensor(
        [-0.0615, -0.2446, -0.0401,  0.1929, -0.1269, -0.1158, -0.0044,  0.1368]
    ))
    return rnn

# Check cell for RNN: run a 5-element sequence through the fixed-weight model
# and compare the outputs and the final hidden state with hard-coded values.
with torch.no_grad():
    # Input sequence: 5 elements of 8 values each.
    x = torch.Tensor(
        [[-1.2776,  0.5119, -1.1238,  0.9175,  0.6008,  0.5584, -0.7533, -0.4042],
        [-0.6983, -0.4434,  0.7114, -0.4369,  0.3358,  0.5429,  0.3797, -1.1896],
        [ 0.8967,  0.0792, -0.5119, -2.5554, -1.0079, -0.4060, -0.9058, -0.3935],
        [ 1.3755, -0.0443,  0.7778,  1.0423, -1.1472, -1.0362, -1.6468,  1.9422],
        [ 1.8704,  1.2734, -0.1887, -0.1851, -0.3144, -0.8346, -0.9275, -0.3772]]
    )
    # Zero initial hidden state (16 values).
    h_0 = torch.Tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

    rnn = setup_rnn()
    y, h_last = rnn(x, h_0)
    
    # Expected outputs (one row per input element) and final hidden state.
    y_true = torch.Tensor(
        [[ 0.0292,  0.0744,  0.2062,  0.1836, -0.3019, -0.2886,  0.3392, -0.0362],
        [-0.2185,  0.0339, -0.1909,  0.4536,  0.1540, -0.0730, -0.1722,  0.3190],
        [ 0.0788, -0.4049,  0.1169,  0.1183, -0.0324,  0.0035,  0.1992,  0.2993],
        [ 0.0203, -0.6270, -0.1633,  0.0290, -0.0267, -0.0026,  0.3516, -0.0876],
        [ 0.0800, -0.4304,  0.1784,  0.0623, -0.0518, -0.2945,  0.6269,  0.1211]]
    )
    h_last_true = torch.Tensor(
        [ 0.4946, -0.0356, -0.6983,  0.1682, -0.1134, -0.3129,  0.3878,  0.5568,
         0.4668,  0.3385,  0.4887,  0.8589, -0.2894, -0.0281,  0.6881,  0.0917]
    )
    assert torch.allclose(y, y_true, atol=1e-4)
    assert torch.allclose(h_last, h_last_true, atol=1e-4)

# Residual Block

Перед вами стоит задача реализовать residual-блок: модуль, который применяет к тензору $X_{in}$ преобразование, сохраняющее его размерность, и прибавляет к результату исходный тензор (skip connection), обеспечивая проброс градиента в обход преобразования.


Реализуйте модуль `ResidualBlock`, принимающий тензор $X_{in}$ `x` и возвращающий тензор $X_{out}$.

В качестве функции активации используйте $ReLU$: $X_{out} = ReLU(subblock(X_{in}) + X_{in})$. Подблок `subblock`, выполняющий преобразования над $X_{in}$ и являющийся `nn.Module`, модуль `ResidualBlock` должен принимать в конструкторе.



In [None]:
class ResidualBlock(nn.Module):
    """Residual (skip-connection) wrapper around an arbitrary sub-module.

    Computes ``ReLU(subblock(x) + x)``.

    Args:
        subblock: module applied to the input; its output must be addable to
            the input (same or broadcast-compatible shape).
    """

    def __init__(self, subblock: nn.Module):
        super().__init__()

        self.subblock = subblock

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sub-module, add the skip connection and activate.

        Args:
            x: input tensor.

        Returns:
            ``ReLU(subblock(x) + x)``.
        """
        # The identity term gives gradients a direct path around the sub-module.
        return F.relu(self.subblock(x) + x)

In [None]:
class TestSubmodule(nn.Module):
    """Parameter-free module used to exercise ``ResidualBlock``: ``x * 0.005 + 1``."""

    def __init__(self):
        super().__init__()

    # The return annotation names the type; the previous ``torch.Tensor()``
    # built an empty tensor at definition time instead.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the elementwise affine transform ``x * 0.005 + 1``."""
        return x * 0.005 + 1

# Check cell for ResidualBlock: wrap TestSubmodule (x * 0.005 + 1) and compare
# the result on a 5 x 2 x 3 tensor with hard-coded expected values.
with torch.no_grad():
    x = torch.Tensor(
        [[[-1.5638,  2.1072, -0.3607],
         [-2.4503,  1.3358,  0.5409]],

        [[ 0.0583, -0.8691, -1.5175],
         [ 1.6005, -0.5242,  0.6406]],

        [[ 0.6401, -0.7931,  0.3953],
         [-0.1662,  1.2120,  1.9067]],

        [[-0.2729, -1.4902, -0.2814],
         [-2.7286, -1.6104,  0.9106]],

        [[-0.3986, -0.0439, -0.3940],
         [-2.2789, -0.8279,  1.2199]]]
    )
    
    subblock = TestSubmodule()
    res = ResidualBlock(subblock)
    
    x_out = res(x)
    
    # Expected values correspond to ReLU(x + (x * 0.005 + 1)) applied elementwise;
    # entries whose pre-activation is negative are clamped to 0.
    x_out_true = torch.Tensor(
    [[[0.0000, 3.1177, 0.6375],
         [0.0000, 2.3425, 1.5436]],

        [[1.0586, 0.1266, 0.0000],
         [2.6085, 0.4732, 1.6438]],

        [[1.6433, 0.2029, 1.3973],
         [0.8330, 2.2181, 2.9162]],

        [[0.7257, 0.0000, 0.7172],
         [0.0000, 0.0000, 1.9152]],

        [[0.5994, 0.9559, 0.6040],
         [0.0000, 0.1680, 2.2260]]]
    )
    assert torch.allclose(x_out, x_out_true, atol=1e-4)

# Языковая модель на RNN

Реализуйте модуль `LanguageModel`, принимающий последовательность слов `words: List[str]` (например, `["привет", "дивный", "новый", "мир"]`) и возвращающий вектор - распределение вероятностей следующего слова/токена по всем возможным объектам словаря.

Модуль должен принимать в конструкторе словарь `word_dictionary: List[str]`.

В нём изначально не будет содержаться специальных токенов. Вам следует добавить их и присвоить им порядковые номера: `<bos>` - 0, `<eos>` - 1, `<unk>` - 2.

Все последующие токены должны иметь порядковые номера, начинающиеся с 3, в соответствии с порядком элементов в `word_dictionary`.

Также в конструкторе передается `input_dim` и `hidden_dim` для RNN.

Порядковый номер вероятности в выходном векторе должен соответствовать порядковому номеру объекта из получившегося словаря.

Переиспользуйте уже реализованный вами в предыдущей задаче модуль `RNN`. В качестве $h_0$ используйте вектор, состоящий из нулей. Не забудьте, что выходных $y$ у RNN столько же, сколько входных $x$, а нужно получить распределение для единственного токена.

Не изменяйте имена и типы указанных в шаблоне атрибутов, чтобы проверяющая система смогла обработать решение.

In [None]:
class LanguageModel(nn.Module):
    """Word-level language model built on top of ``RNN``.

    The vocabulary consists of the special tokens ``<bos>`` (0), ``<eos>`` (1)
    and ``<unk>`` (2) followed by the words of ``word_dictionary`` in their
    original order, starting at index 3.

    Args:
        word_dictionary: vocabulary words, without special tokens.
        input_dim: embedding size, i.e. the RNN input size.
        hidden_dim: RNN hidden state size.
    """

    # Special tokens occupy the first vocabulary indices, in this order.
    SPECIAL_TOKENS = ('<bos>', '<eos>', '<unk>')

    def __init__(self, word_dictionary: t.List[str], input_dim: int, hidden_dim: int):
        super().__init__()
        self.word_to_index = self._build_vocabulary(word_dictionary)
        self.hidden_dim = hidden_dim
        # One embedding row and one output logit per vocabulary position.
        vocab_size = len(self.SPECIAL_TOKENS) + len(word_dictionary)
        self.embed = nn.Embedding(vocab_size, input_dim)
        self.rnn = RNN(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=vocab_size)

    @staticmethod
    def _build_vocabulary(word_dictionary: t.List[str]) -> t.Dict[str, int]:
        """Map special tokens to 0..2 and dictionary words to 3, 4, ... in order."""
        tokens = list(LanguageModel.SPECIAL_TOKENS) + list(word_dictionary)
        word_to_index: t.Dict[str, int] = {}
        for index, token in enumerate(tokens):
            # setdefault keeps the reserved special-token indices (and the first
            # occurrence of a duplicated word) from being overwritten.
            word_to_index.setdefault(token, index)
        return word_to_index

    def forward(self, words: t.List[str]) -> torch.Tensor:
        """Return the probability distribution of the token following ``words``.

        Args:
            words: non-empty list of words; words missing from the vocabulary
                are mapped to ``<unk>``.

        Returns:
            1-D tensor of vocabulary size that sums to 1; position ``i`` is the
            probability of the token with vocabulary index ``i``.

        Raises:
            ValueError: if ``words`` is empty.
        """
        if len(words) == 0:
            raise ValueError("words must contain at least one token")

        unk_index = self.word_to_index['<unk>']
        indices = [self.word_to_index.get(word, unk_index) for word in words]
        token_ids = torch.tensor(indices, dtype=torch.long, device=self.embed.weight.device)

        x = self.embed(token_ids)
        # Zero initial hidden state with the same dtype/device as the model.
        h_0 = self.embed.weight.new_zeros(self.hidden_dim)
        y, _ = self.rnn(x, h_0)

        # The RNN emits one output per input word; only the last one predicts
        # the next token.
        return torch.softmax(y[-1], dim=-1)

In [None]:
@torch.no_grad()  # in-place copy_() into leaf parameters is only legal with autograd disabled
def setup_lm():
    """Build a ``LanguageModel`` over a 10-word dictionary with fixed parameters.

    The embedding and RNN cell parameters are overwritten in place with
    hard-coded values so that the forward pass is reproducible. Gradient
    tracking is disabled by the decorator, so the helper can be called from
    any context, not only from inside a ``torch.no_grad()`` block.

    Returns:
        The configured ``LanguageModel`` instance.
    """
    lm = LanguageModel(
        word_dictionary=['привет', 'собака', 'кролик', 'пень', 'дивный', 'новый', 'мир', 'я', 'а', 'не'],
        input_dim=8,
        hidden_dim=16
    )

    # Embedding table: 13 rows (3 special tokens + 10 words) x 8 columns.
    lm.embed.weight.copy_(torch.Tensor(
        [[ 0.3622, -0.4547, -0.6780,  0.2229,  0.7924, -0.4368,  1.9774,  0.9146],
         [ 1.4505, -0.6534, -0.0448, -1.2961,  2.5839, -0.0052,  0.2903, -0.3439],
         [ 1.2504, -1.0109, -1.3156,  0.2381,  2.0347, -0.9224, -2.2758,  0.1573],
         [ 0.1574,  0.7840, -0.0453,  1.1687,  0.1508,  0.2849,  0.7311, -0.8141],
         [ 1.1858,  1.9815, -1.0088, -0.4776,  1.1550,  2.2777,  0.4954, -0.1482],
         [ 1.7633,  0.3523,  1.0210, -0.7210,  0.1320,  0.3525, -0.9463,  1.0663],
         [-1.2583,  0.6247, -1.3976, -0.0705, -0.7252,  0.6046, -0.7657, -1.0576],
         [-0.0527, -1.0089, -1.9000, -0.1904,  1.2083,  0.6260,  0.0978,  0.1958],
         [-0.4728, -0.3897,  1.2230, -1.7291, -0.4821,  1.3033, -0.6947, -2.0038],
         [-1.2069, -0.7040,  2.3203, -1.0961,  0.1280,  0.2123, -0.0268,  0.2074],
         [-0.6628, -0.0188,  0.1210,  0.3474,  1.5741,  0.3325, -0.5522,  0.8874],
         [-1.1994,  1.3444, -0.2977, -0.7457,  2.0891, -0.6886, -0.2262,  0.3408],
         [-0.7926,  0.5283, -0.1947, -1.5606, -0.3685,  0.4950,  0.1381,  0.5355]]
    ))

    # W_xh: 16 rows x 8 columns.
    lm.rnn.cell.W_xh.copy_(torch.Tensor(
        [[ 0.0406,  0.2349,  0.0105,  0.2168,  0.1599,  0.0331, -0.0509, -0.0447],
        [ 0.1662,  0.0130, -0.0337, -0.1074, -0.1083,  0.1661,  0.0699,  0.0472],
        [-0.2137,  0.0695,  0.2388,  0.0276, -0.2023,  0.1200,  0.1877, -0.0928],
        [ 0.0446,  0.1321,  0.0169,  0.0440,  0.1799,  0.2268, -0.1644,  0.1853],
        [ 0.0688,  0.2237,  0.1891, -0.1934,  0.1648,  0.1595,  0.1382,  0.1909],
        [-0.0313, -0.1928,  0.0502, -0.0801,  0.1119,  0.2496,  0.0955, -0.1018],
        [ 0.1007,  0.1439, -0.1244, -0.0602, -0.0404, -0.0751,  0.0753, -0.2126],
        [ 0.1852, -0.0228,  0.0414,  0.1247, -0.0382, -0.2073, -0.0217,  0.1214],
        [ 0.1791,  0.1628,  0.1613,  0.0748, -0.1029,  0.2333,  0.0974, -0.1097],
        [ 0.2478,  0.1273, -0.0543,  0.2002, -0.2210,  0.1301, -0.0912,  0.2429],
        [ 0.0932,  0.0508, -0.0118, -0.0025,  0.1182,  0.0308, -0.0185,  0.0785],
        [ 0.2497,  0.2434, -0.2230,  0.2131, -0.0506, -0.1561, -0.0063,  0.0071],
        [-0.2488, -0.2414, -0.2098, -0.1004, -0.0247, -0.2486,  0.0511, -0.1239],
        [ 0.0170,  0.1625, -0.2013, -0.1176,  0.1106,  0.1020,  0.1188, -0.0986],
        [ 0.0347,  0.0508, -0.2134,  0.1188, -0.1166, -0.1257, -0.0018, -0.1294],
        [ 0.1914, -0.1379,  0.2048, -0.1244,  0.0215,  0.0065,  0.1909,  0.1165]]
    ))
    # W_hh: 16 rows x 16 columns (each row is wrapped over two source lines).
    lm.rnn.cell.W_hh.copy_(torch.Tensor(
        [[ 0.0478, -0.1842,  0.0311, -0.1609,  0.1658,  0.1759,  0.0606,  0.1576,
          0.0005, -0.1109, -0.0180,  0.0627, -0.2300, -0.1832, -0.0912,  0.2232],
        [ 0.0773, -0.2239, -0.1193, -0.2469,  0.2075, -0.1265,  0.0925, -0.1849,
          0.2264, -0.1034, -0.2423,  0.1826,  0.2086, -0.0932,  0.0459, -0.0289],
        [ 0.0346,  0.0026,  0.2352, -0.1480,  0.1724, -0.1090,  0.2177,  0.1470,
         -0.1529, -0.1101,  0.0953, -0.1780, -0.2014,  0.1175,  0.0034, -0.2008],
        [-0.0272, -0.2022,  0.0973, -0.0868,  0.0937, -0.0744, -0.2322,  0.1686,
          0.2041,  0.2289, -0.1019, -0.1307,  0.1082,  0.1127,  0.0810, -0.0212],
        [-0.0625, -0.1929,  0.1032,  0.0717, -0.2095, -0.0833,  0.2060,  0.0393,
         -0.0017, -0.2424,  0.1011,  0.1084, -0.0507,  0.0380,  0.1606, -0.1064],
        [ 0.2354,  0.1944, -0.0348, -0.0791, -0.1073,  0.0490, -0.0145,  0.2091,
         -0.0509,  0.0794,  0.1676, -0.0799,  0.0610,  0.0923,  0.1328, -0.0943],
        [-0.1152, -0.1263, -0.0177, -0.1844,  0.2209, -0.1556, -0.0231, -0.2491,
         -0.1764,  0.0462,  0.0301, -0.0514, -0.0831,  0.2381, -0.0296, -0.1493],
        [ 0.0875,  0.0682, -0.2049,  0.2127,  0.1225, -0.0340,  0.2139, -0.1940,
         -0.1596,  0.2279,  0.0067, -0.2343,  0.0280, -0.0747, -0.0147,  0.1535],
        [-0.0293,  0.1301,  0.1968, -0.1572,  0.0800, -0.2406, -0.2052, -0.1352,
          0.0198, -0.1625, -0.2067,  0.0007, -0.2270, -0.2007, -0.0817,  0.0085],
        [-0.0042,  0.1470,  0.1297, -0.1006, -0.1358, -0.1872, -0.2042,  0.2296,
         -0.2233, -0.1409,  0.1878,  0.0304,  0.2365,  0.1029,  0.1432,  0.2008],
        [ 0.1693, -0.0513, -0.0203, -0.0693,  0.1079, -0.0635, -0.1428,  0.0494,
          0.0802,  0.1385, -0.0716,  0.0783, -0.0984, -0.1418, -0.0036,  0.0212],
        [ 0.2301,  0.1138, -0.1399,  0.1093, -0.1122, -0.2258, -0.2126,  0.2026,
          0.2073, -0.1753, -0.0526, -0.2077,  0.1696, -0.0842, -0.0802,  0.2134],
        [-0.1557, -0.0470, -0.1055, -0.0037, -0.1743, -0.0108,  0.2397,  0.1491,
          0.1103,  0.1330,  0.2384,  0.1541, -0.0777,  0.2000,  0.1809, -0.1571],
        [ 0.1807,  0.1401,  0.2319, -0.2321,  0.1184, -0.2419, -0.0857, -0.0775,
         -0.0775,  0.0141, -0.0578,  0.0138,  0.1883,  0.0045, -0.1924, -0.2058],
        [-0.1946,  0.0101,  0.1604, -0.1925,  0.0242, -0.1358,  0.0951, -0.0285,
         -0.2319,  0.1631, -0.0933,  0.1820, -0.0486, -0.1240,  0.0724,  0.1316],
        [-0.0019,  0.0008, -0.0875,  0.1283, -0.0355,  0.0270, -0.0059,  0.2113,
         -0.2407, -0.1975,  0.2467,  0.1493, -0.1066, -0.1537,  0.1096, -0.2061]]
    ))
    lm.rnn.cell.b_h.copy_(torch.Tensor(
        [ 0.0386, -0.2340, -0.1360, -0.0949, -0.0941,  0.1709,  0.2353,  0.2209,
         0.1486, -0.2328, -0.0414,  0.2084, -0.0596,  0.0189,  0.1878, -0.0778]
    ))
    # W_hy: 13 rows (one per vocabulary entry) x 16 columns.
    lm.rnn.cell.W_hy.copy_(torch.Tensor(
        [[-0.0462,  0.2208, -0.1284, -0.0817, -0.0008,  0.2419,  0.0263, -0.1479,
         -0.1397, -0.0603, -0.0509,  0.1095,  0.0603, -0.1729, -0.0180,  0.0686],
        [-0.0227,  0.0257, -0.1776, -0.1790,  0.0879, -0.0701,  0.2108,  0.0953,
         -0.0552, -0.1414, -0.0260, -0.0821,  0.0494,  0.0508,  0.0885, -0.0157],
        [ 0.0320, -0.0050, -0.1238, -0.1552,  0.1842, -0.1297,  0.0862,  0.1281,
         -0.0428, -0.2060, -0.2142,  0.1671,  0.0621,  0.0634,  0.1296, -0.1988],
        [ 0.2323,  0.1116, -0.0542,  0.1159,  0.2023,  0.1368,  0.2135,  0.2343,
         -0.1120,  0.2311, -0.1254, -0.0499, -0.0044, -0.1102, -0.1216, -0.2289],
        [-0.1302, -0.1570, -0.1188, -0.1822,  0.2372,  0.0932,  0.0379, -0.1406,
         -0.0212,  0.1506, -0.2328, -0.1376, -0.1314, -0.1642,  0.2360,  0.1479],
        [-0.1579,  0.2255,  0.1142,  0.0890,  0.1256, -0.0717,  0.0024,  0.0214,
          0.0239, -0.2026,  0.1472,  0.0751,  0.0851, -0.0103, -0.2098,  0.1781],
        [-0.2182,  0.0330, -0.0726, -0.0199,  0.1129,  0.1696, -0.0408,  0.1237,
          0.1556,  0.1023, -0.0907, -0.0040, -0.0122, -0.1988,  0.0704, -0.0694],
        [ 0.0143, -0.1247,  0.1846,  0.1248, -0.2228, -0.2073, -0.1841, -0.0053,
          0.0968, -0.1930, -0.2422,  0.0636,  0.1312, -0.0352, -0.2495, -0.1774],
        [ 0.0025,  0.1152, -0.1659, -0.0773, -0.0250,  0.0786, -0.2134, -0.2233,
          0.2255, -0.0485,  0.1649, -0.2103, -0.1832,  0.2490, -0.1452, -0.0236],
        [-0.1660, -0.2005, -0.0618, -0.0775, -0.1834, -0.1649, -0.2391, -0.0353,
          0.1663, -0.2007, -0.1393, -0.1652,  0.2421,  0.0932, -0.1770,  0.1883],
        [ 0.1616,  0.0254, -0.0749, -0.1515,  0.0366, -0.1876, -0.1172,  0.0925,
          0.2407,  0.0854,  0.1835, -0.0675, -0.0365,  0.1523, -0.1558, -0.2135],
        [ 0.2083,  0.0766, -0.0294, -0.1983,  0.0559, -0.0492,  0.1288,  0.1547,
          0.0206, -0.0386,  0.0681,  0.0390,  0.0114,  0.0146, -0.1899,  0.0131],
        [-0.2321,  0.0265, -0.0242, -0.0816,  0.0285,  0.2046,  0.1819, -0.2399,
         -0.0004,  0.2073, -0.2124, -0.0300,  0.0061, -0.0938, -0.0983,  0.1197]]
    ))
    lm.rnn.cell.b_y.copy_(torch.Tensor(
        [ 0.1920, -0.2377, -0.0254,  0.0027, -0.0628, -0.1779,  0.0603, -0.1282,
         0.0612, -0.1540, -0.1511,  0.1362, -0.2046]
    ))
    return lm

# Check cell for LanguageModel: predict the next-token distribution for a
# 4-word input and compare with hard-coded expected probabilities.
with torch.no_grad():
    lm = setup_lm()
    out = lm(['я', 'собака', 'а', 'не'])
    # 13 expected probabilities: 3 special tokens + 10 dictionary words.
    out_true = torch.Tensor(
        [0.1027, 0.0676, 0.0782, 0.0787, 0.0785, 0.0824, 0.0731, 0.0563, 0.0876,
        0.0616, 0.0564, 0.0994, 0.0776]
    )
    assert torch.allclose(out, out_true, atol=1e-4)