Skip to content

Commit

Permalink
More work in simplifying matplotlib imports
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescAlted committed Sep 23, 2016
1 parent b16bffc commit 41c18c0
Showing 1 changed file with 23 additions and 31 deletions.
54 changes: 23 additions & 31 deletions reflexible/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,18 @@

import yaml
import numpy as np
import matplotlib as mpl
from matplotlib.collections import LineCollection

try:
from mpl_toolkits.basemap import Basemap
except ImportError:
from matplotlib.toolkits.basemap import Basemap
from netCDF4 import Dataset as NetCDFFile
from PIL import Image

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, shiftgrid
from netCDF4 import Dataset as NetCDFFile
from mpl_toolkits import basemap

# mpl.use("Agg")
# mp.interactive(False)
# mpl.use('Agg')
# !!NEED TO FIX!! #
from matplotlib.ticker import NullFormatter

# TODO: !!NEED TO FIX!!
TEX = False

__author__ = "John F Burkhart <jfburkhart@gmail.com>"
Expand Down Expand Up @@ -367,7 +362,7 @@ def plot_ECMWF(nc, variable, time, level, map_region='north_america'):
data = alldata[time, level, :, :]
data = data[::-1, :] # flip lats
# shift the grid to go from -180 to 180
d2, lon2 = shiftgrid(-180, data, lon1)
d2, lon2 = basemap.shiftgrid(-180, data, lon1)
plot_grid((lon2, lat2, d2), map_region=map_region)


Expand Down Expand Up @@ -624,8 +619,7 @@ def get_base1(map_region=1,
else:
ax = fig.gca()

print(Basemap)
m = Basemap(**map_par_sd)
m = basemap.Basemap(**map_par_sd)
print('getting base1')
print(m)
plt.axes(ax) # make sure axes ax are current
Expand Down Expand Up @@ -677,7 +671,7 @@ def get_base2(**kwargs):
mr = {'map_region': reg}
mp, fig_par = map_regions(map_region=reg)

m = Basemap(**mp)
m = basemap.Basemap(**mp)
ax = fig.add_axes(
[0.1, 0.1, 0.7, 0.7]) # need to change back to [0.1,0.1,0.7,0.7]
plt.axes(ax) # make the original axes current again
Expand Down Expand Up @@ -731,7 +725,7 @@ def get_base3(**kwargs):
# fig = plt.figure(1, figsize=(57.625, 43.75))
# Use map_regions function to define input paramters for Basemap
mp, fig_par = map_regions(map_region)
m = Basemap(**mp)
m = basemap.Basemap(**mp)
ax = fig.add_axes(
[0, 0, 1, 1],
frameon=False) # need to change back to [0.1,0.1,0.7,0.7]
Expand Down Expand Up @@ -787,7 +781,7 @@ def get_base_image(imagefile, **kwargs):
mr = {'map_region': reg}
mp, fig_par = map_regions(map_region=reg)

m = Basemap(**mp)
m = basemap.Basemap(**mp)
ax = fig.add_axes(
[0.1, 0.1, 0.7, 0.7]) # need to change back to [0.1,0.1,0.7,0.7]
plt.axes(ax) # make the original axes current again
Expand Down Expand Up @@ -910,7 +904,7 @@ def plot_track(lon, lat,
fig = FIGURE.fig
m = FIGURE.m
ax = FIGURE.ax
nullfmt = NullFormatter()
nullfmt = mpl.ticker.NullFormatter()
if ax is None:
az = plt.gca()
pos = az.get_position()
Expand Down Expand Up @@ -974,9 +968,9 @@ def plot_track(lon, lat,
points = np.array([cx, cy]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

lc = LineCollection(segments, cmap=plt.get_cmap('jet'),
norm=plt.Normalize(zlevel.min(),
zlevel.max()))
lc = mpl.collections.LineCollection(
segments, cmap=plt.get_cmap('jet'),
norm=plt.Normalize(zlevel.min(), zlevel.max()))
lc.set_array(zlevel.flatten())
lc.set_linewidth(3)
plt.gca().add_collection(lc)
Expand All @@ -985,9 +979,9 @@ def plot_track(lon, lat,
points = np.array([cx[i, :], cy[i, :]]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]],
axis=1)
lc = LineCollection(segments, cmap=plt.get_cmap('jet'),
norm=plt.Normalize(zlevel.min(),
zlevel.max()))
lc = mpl.collections.LineCollection(
segments, cmap=plt.get_cmap('jet'),
norm=plt.Normalize(zlevel.min(), zlevel.max()))
lc.set_array(zlevel.flatten())
lc.set_linewidth(3)
plt.gca().add_collection(lc)
Expand Down Expand Up @@ -1086,13 +1080,11 @@ def plot_grid(D, map_region='polarcat', dres=0.5,
a = 6378.273e3
ec = 0.081816153
b = a * np.sqrt(1. - ec ** 2)
m = Basemap(projection='stere', lat_0=90, lon_0=-45, lat_ts=70,
llcrnrlat=33.92, llcrnrlon=279.96,
urcrnrlon=102.34, urcrnrlat=31.37,
rsphere=(a, b))
# m = Basemap(width=12000000,height=8000000,
# resolution='l',projection='npstere',
# lat_ts=50,lat_0=50,lon_0=-107.)#Set up a basemap
m = basemap.Basemap(
projection='stere', lat_0=90, lon_0=-45, lat_ts=70,
llcrnrlat=33.92, llcrnrlon=279.96,
urcrnrlon=102.34, urcrnrlat=31.37,
rsphere=(a, b))
m.drawcoastlines()

if points:
Expand Down

0 comments on commit 41c18c0

Please sign in to comment.